You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
3.1 KiB
Go
138 lines
3.1 KiB
Go
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package oauth2
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.awesome-for.me/liuzhiguo/go-scm/scm"
|
|
)
|
|
|
|
// expiryDelta determines how earlier a token should be considered
|
|
// expired than its actual expiration time. It is used to avoid late
|
|
// expirations due to client-server time mismatches.
|
|
const expiryDelta = time.Minute
|
|
|
|
// Refresher is an http.RoundTripper that refreshes oauth
|
|
// tokens, wrapping a base RoundTripper and refreshing the
|
|
// token if expired.
|
|
//
|
|
// IMPORTANT the Refresher is NOT safe for concurrent use
|
|
// by multiple goroutines.
|
|
type Refresher struct {
|
|
ClientID string
|
|
ClientSecret string
|
|
Endpoint string
|
|
|
|
Source scm.TokenSource
|
|
Client *http.Client
|
|
}
|
|
|
|
// Token returns a token. If the token is missing or
|
|
// expired, the token is refreshed.
|
|
func (t *Refresher) Token(ctx context.Context) (*scm.Token, error) {
|
|
token, err := t.Source.Token(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !expired(token) {
|
|
return token, nil
|
|
}
|
|
err = t.Refresh(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
// Refresh refreshes the expired token.
|
|
func (t *Refresher) Refresh(token *scm.Token) error {
|
|
values := url.Values{}
|
|
values.Set("grant_type", "refresh_token")
|
|
values.Set("refresh_token", token.Refresh)
|
|
|
|
reader := strings.NewReader(
|
|
values.Encode(),
|
|
)
|
|
req, err := http.NewRequest("POST", t.Endpoint, reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.SetBasicAuth(t.ClientID, t.ClientSecret)
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
res, err := t.client().Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
if res.StatusCode > 299 {
|
|
out := new(tokenError)
|
|
err = json.NewDecoder(res.Body).Decode(out)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return out
|
|
}
|
|
|
|
out := new(tokenGrant)
|
|
err = json.NewDecoder(res.Body).Decode(out)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
token.Token = out.Access
|
|
token.Refresh = out.Refresh
|
|
token.Expires = time.Now().Add(
|
|
time.Duration(out.Expires) * time.Second,
|
|
)
|
|
return nil
|
|
}
|
|
|
|
// client returns the http transport. If no base client
|
|
// is configured, the default client is returned.
|
|
func (t *Refresher) client() *http.Client {
|
|
if t.Client != nil {
|
|
return t.Client
|
|
}
|
|
return http.DefaultClient
|
|
}
|
|
|
|
// expired reports whether the token is expired.
|
|
func expired(token *scm.Token) bool {
|
|
if len(token.Refresh) == 0 {
|
|
return false
|
|
}
|
|
if token.Expires.IsZero() && len(token.Token) != 0 {
|
|
return false
|
|
}
|
|
return token.Expires.Add(-expiryDelta).
|
|
Before(time.Now())
|
|
}
|
|
|
|
// tokenGrant is the token returned by the token endpoint.
|
|
type tokenGrant struct {
|
|
Access string `json:"access_token"`
|
|
Refresh string `json:"refresh_token"`
|
|
Expires int64 `json:"expires_in"`
|
|
}
|
|
|
|
// tokenError is the error returned when the token endpoint
|
|
// returns a non-2XX HTTP status code.
|
|
type tokenError struct {
|
|
Code string `json:"error"`
|
|
Message string `json:"error_description"`
|
|
}
|
|
|
|
func (t *tokenError) Error() string {
|
|
return t.Message
|
|
}
|