44package api
55
66import (
7- "context"
87 "encoding/json"
98 "fmt"
109 "net/http"
1110 "strings"
1211
1312 grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
14- smqauth "github.com/absmach/supermq/auth"
1513 "github.com/absmach/supermq/pkg/oauth2"
1614 "github.com/absmach/supermq/users"
15+ useroauth "github.com/absmach/supermq/users/oauth"
1716 "github.com/go-chi/chi/v5"
18- goauth2 "golang.org/x/oauth2"
1917)
2018
2119var (
@@ -43,18 +41,18 @@ type authURLResponse struct {
4341// - GET /oauth/authorize/{provider} - Returns the authorization URL
4442// - GET /oauth/callback/{provider} - Handles OAuth2 callback and sets cookies
4543// - POST /oauth/cli/callback/{provider} - Handles CLI OAuth2 callback and returns JSON.
46- func oauthHandler (r * chi.Mux , svc users.Service , tokenClient grpcTokenV1.TokenServiceClient , deviceStore DeviceCodeStore , providers ... oauth2.Provider ) * chi.Mux {
44+ func oauthHandler (r * chi.Mux , svc users.Service , tokenClient grpcTokenV1.TokenServiceClient , oauthSvc useroauth. Service , providers ... oauth2.Provider ) * chi.Mux {
4745 for _ , provider := range providers {
48- r .HandleFunc ("/oauth/callback/" + provider .Name (), oauth2CallbackHandler (provider , svc , tokenClient , deviceStore ))
46+ r .HandleFunc ("/oauth/callback/" + provider .Name (), oauth2CallbackHandler (provider , oauthSvc ))
4947 r .Get ("/oauth/authorize/" + provider .Name (), oauth2AuthorizeHandler (provider ))
50- r .Post ("/oauth/cli/callback/" + provider .Name (), oauth2CLICallbackHandler (provider , svc , tokenClient ))
48+ r .Post ("/oauth/cli/callback/" + provider .Name (), oauth2CLICallbackHandler (provider , oauthSvc ))
5149 }
5250
5351 return r
5452}
5553
5654// oauth2CallbackHandler is a http.HandlerFunc that handles OAuth2 callbacks.
57- func oauth2CallbackHandler (oauth oauth2.Provider , svc users .Service , tokenClient grpcTokenV1. TokenServiceClient , deviceStore DeviceCodeStore ) http.HandlerFunc {
55+ func oauth2CallbackHandler (oauth oauth2.Provider , oauthSvc useroauth .Service ) http.HandlerFunc {
5856 return func (w http.ResponseWriter , r * http.Request ) {
5957 if ! oauth .IsEnabled () {
6058 redirectWithError (w , r , oauth .ErrorURL (), "oauth provider is disabled" )
@@ -64,8 +62,8 @@ func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service, tokenClient
6462 state := r .FormValue ("state" )
6563
6664 // Check if this is a device flow callback (state contains device: prefix)
67- if strings . HasPrefix (state , "device:" ) {
68- handleDeviceFlowCallback (w , r , oauth , svc , tokenClient , deviceStore )
65+ if useroauth . IsDeviceFlowState (state ) {
66+ handleDeviceFlowCallback (w , r , oauth , oauthSvc )
6967 return
7068 }
7169
@@ -80,13 +78,7 @@ func oauth2CallbackHandler(oauth oauth2.Provider, svc users.Service, tokenClient
8078 return
8179 }
8280
83- token , err := oauth .Exchange (r .Context (), code )
84- if err != nil {
85- redirectWithError (w , r , oauth .ErrorURL (), err .Error ())
86- return
87- }
88-
89- jwt , err := processOAuthUser (r .Context (), oauth , token .AccessToken , svc , tokenClient )
81+ jwt , err := oauthSvc .ProcessWebCallback (r .Context (), oauth , code , "" )
9082 if err != nil {
9183 redirectWithError (w , r , oauth .ErrorURL (), err .Error ())
9284 return
@@ -125,7 +117,7 @@ func oauth2AuthorizeHandler(oauth oauth2.Provider) http.HandlerFunc {
125117}
126118
127119// oauth2CLICallbackHandler handles OAuth2 callbacks for CLI and returns JSON tokens.
128- func oauth2CLICallbackHandler (oauth oauth2.Provider , svc users .Service , tokenClient grpcTokenV1. TokenServiceClient ) http.HandlerFunc {
120+ func oauth2CLICallbackHandler (oauth oauth2.Provider , oauthSvc useroauth .Service ) http.HandlerFunc {
129121 return func (w http.ResponseWriter , r * http.Request ) {
130122 w .Header ().Set ("Content-Type" , "application/json" )
131123
@@ -155,13 +147,7 @@ func oauth2CLICallbackHandler(oauth oauth2.Provider, svc users.Service, tokenCli
155147 return
156148 }
157149
158- token , err := exchangeCode (r .Context (), oauth , req .Code , req .RedirectURL )
159- if err != nil {
160- respondWithJSON (w , http .StatusUnauthorized , newErrorResponse (err .Error ()))
161- return
162- }
163-
164- jwt , err := processOAuthUser (r .Context (), oauth , token .AccessToken , svc , tokenClient )
150+ jwt , err := oauthSvc .ProcessWebCallback (r .Context (), oauth , req .Code , req .RedirectURL )
165151 if err != nil {
166152 status := http .StatusInternalServerError
167153 if err .Error () == "unauthorized" {
@@ -176,45 +162,6 @@ func oauth2CLICallbackHandler(oauth oauth2.Provider, svc users.Service, tokenCli
176162 }
177163}
178164
179- // exchangeCode exchanges an authorization code for an access token.
180- // If redirectURL is provided, it uses ExchangeWithRedirect, otherwise uses Exchange.
181- func exchangeCode (ctx context.Context , provider oauth2.Provider , code , redirectURL string ) (goauth2.Token , error ) {
182- if redirectURL != "" {
183- return provider .ExchangeWithRedirect (ctx , code , redirectURL )
184- }
185- return provider .Exchange (ctx , code )
186- }
187-
188- // processOAuthUser retrieves user info from the OAuth provider, creates or updates the user,
189- // adds user policies, and issues a JWT token.
190- func processOAuthUser (ctx context.Context , provider oauth2.Provider , accessToken string , svc users.Service , tokenClient grpcTokenV1.TokenServiceClient ) (* grpcTokenV1.Token , error ) {
191- user , err := provider .UserInfo (accessToken )
192- if err != nil {
193- return nil , err
194- }
195-
196- user .AuthProvider = provider .Name ()
197- if user .AuthProvider == "" {
198- user .AuthProvider = "oauth"
199- }
200-
201- user , err = svc .OAuthCallback (ctx , user )
202- if err != nil {
203- return nil , err
204- }
205-
206- if err := svc .OAuthAddUserPolicy (ctx , user ); err != nil {
207- return nil , err
208- }
209-
210- return tokenClient .Issue (ctx , & grpcTokenV1.IssueReq {
211- UserId : user .ID ,
212- Type : uint32 (smqauth .AccessKey ),
213- UserRole : uint32 (smqauth .UserRole ),
214- Verified : ! user .VerifiedAt .IsZero (),
215- })
216- }
217-
218165// respondWithJSON writes a JSON response with the given status code and data.
219166func respondWithJSON (w http.ResponseWriter , status int , data any ) {
220167 w .WriteHeader (status )
@@ -248,13 +195,13 @@ func setTokenCookies(w http.ResponseWriter, jwt *grpcTokenV1.Token) {
248195}
249196
250197// handleDeviceFlowCallback processes OAuth callback for device authorization flow.
251- func handleDeviceFlowCallback (w http.ResponseWriter , r * http.Request , oauth oauth2.Provider , svc users .Service , tokenClient grpcTokenV1. TokenServiceClient , deviceStore DeviceCodeStore ) {
198+ func handleDeviceFlowCallback (w http.ResponseWriter , r * http.Request , oauth oauth2.Provider , oauthSvc useroauth .Service ) {
252199 // Extract user code from state (format: "device:ABCD-EFGH")
253200 state := r .FormValue ("state" )
254- userCode := strings . TrimPrefix (state , "device:" )
201+ userCode := useroauth . ExtractUserCodeFromState (state )
255202
256- // Get device code by user code
257- deviceCode , err := deviceStore . GetByUserCode ( userCode )
203+ // Get device code by user code to validate it exists
204+ _ , err := oauthSvc . GetDeviceCodeByUserCode ( r . Context (), userCode )
258205 if err != nil {
259206 w .Header ().Set ("Content-Type" , "text/html" )
260207 fmt .Fprint (w , strings .Replace (errorHTML , "{{ERROR_MESSAGE}}" , "The device code is invalid or has expired." , 1 ))
@@ -269,20 +216,10 @@ func handleDeviceFlowCallback(w http.ResponseWriter, r *http.Request, oauth oaut
269216 return
270217 }
271218
272- // Exchange OAuth code for token
273- token , err := oauth .Exchange (r .Context (), code )
274- if err != nil {
275- w .Header ().Set ("Content-Type" , "text/html" )
276- fmt .Fprint (w , strings .Replace (errorHTML , "{{ERROR_MESSAGE}}" , fmt .Sprintf ("Failed to exchange code: %s." , err .Error ()), 1 ))
277- return
278- }
279-
280- // Mark device code as approved with access token
281- deviceCode .Approved = true
282- deviceCode .AccessToken = token .AccessToken
283- if err := deviceStore .Update (deviceCode ); err != nil {
219+ // Process the device callback
220+ if err := oauthSvc .ProcessDeviceCallback (r .Context (), oauth , userCode , code ); err != nil {
284221 w .Header ().Set ("Content-Type" , "text/html" )
285- fmt .Fprint (w , strings .Replace (errorHTML , "{{ERROR_MESSAGE}}" , fmt .Sprintf ("Failed to approve device : %s." , err .Error ()), 1 ))
222+ fmt .Fprint (w , strings .Replace (errorHTML , "{{ERROR_MESSAGE}}" , fmt .Sprintf ("Failed to process callback : %s." , err .Error ()), 1 ))
286223 return
287224 }
288225
0 commit comments