You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
148 lines
3.8 KiB
Go
148 lines
3.8 KiB
Go
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package oauth1
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.awesome-for.me/liuzhiguo/go-scm/scm"
|
|
"git.awesome-for.me/liuzhiguo/go-scm/scm/transport/internal"
|
|
)
|
|
|
|
// clock provides an interface for current time providers. A clock can be
// used in place of calling time.Now() directly; the Transport falls back to
// time.Now when no clock is configured.
type clock interface {
	// Now returns the time to treat as the current instant.
	Now() time.Time
}
|
|
|
|
// A noncer provides nonce strings for the oauth_nonce protocol parameter.
// When the Transport has no noncer configured it generates a random nonce
// itself.
type noncer interface {
	// Nonce returns the nonce string to use for a request.
	Nonce() string
}
|
|
|
|
// Transport is an http.RoundTripper that wraps a base RoundTripper and adds
// an OAuth1 Authorization header, signed with the RSA-SHA1 method, to every
// request for which Source yields a token.
type Transport struct {
	// ConsumerKey is sent as the oauth_consumer_key parameter.
	ConsumerKey string

	// PrivateKey is the consumer's RSA private key, used to sign requests.
	PrivateKey *rsa.PrivateKey

	// Source supplies the Token to add to the request
	// Authorization headers.
	Source scm.TokenSource

	// Base is the base RoundTripper used to make requests.
	// If nil, http.DefaultTransport is used.
	Base http.RoundTripper

	// noncer optionally overrides nonce generation. If nil, a random
	// nonce is generated for each request.
	noncer noncer

	// clock optionally overrides the time source used for the
	// oauth_timestamp parameter. If nil, time.Now is used.
	clock clock
}
|
|
|
|
// RoundTrip authorizes and authenticates the request with
|
|
// an access token from the request context.
|
|
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
|
ctx := r.Context()
|
|
token, err := t.Source.Token(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if token == nil {
|
|
return t.base().RoundTrip(r)
|
|
}
|
|
r2 := internal.CloneRequest(r)
|
|
err = t.setRequestAuthHeader(r2, token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return t.base().RoundTrip(r2)
|
|
}
|
|
|
|
// base returns the base transport. If no base transport
|
|
// is configured, the default transport is returned.
|
|
func (t *Transport) base() http.RoundTripper {
|
|
if t.Base != nil {
|
|
return t.Base
|
|
}
|
|
return http.DefaultTransport
|
|
}
|
|
|
|
// setRequestAuthHeader sets the OAuth1 header for making
|
|
// authenticated requests with an AccessToken according to
|
|
// RFC 5849 3.1.
|
|
func (t *Transport) setRequestAuthHeader(r *http.Request, token *scm.Token) error {
|
|
oauthParams := t.commonOAuthParams()
|
|
oauthParams["oauth_token"] = token.Token
|
|
params := collectParameters(r, oauthParams)
|
|
|
|
signatureBase := signatureBase(r, params)
|
|
signature, err := sign(t.PrivateKey, signatureBase)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
oauthParams["oauth_signature"] = signature
|
|
r.Header.Set("Authorization", authHeaderValue(oauthParams))
|
|
return nil
|
|
}
|
|
|
|
// commonOAuthParams returns a map of the common OAuth1
|
|
// protocol parameters, excluding the oauth_signature.
|
|
func (t *Transport) commonOAuthParams() map[string]string {
|
|
return map[string]string{
|
|
"oauth_consumer_key": t.ConsumerKey,
|
|
"oauth_signature_method": "RSA-SHA1",
|
|
"oauth_timestamp": strconv.FormatInt(t.epoch(), 10),
|
|
"oauth_nonce": t.nonce(),
|
|
"oauth_version": "1.0",
|
|
}
|
|
}
|
|
|
|
// Returns a base64 encoded random 32 byte string.
|
|
func (t *Transport) nonce() string {
|
|
if t.noncer != nil {
|
|
return t.noncer.Nonce()
|
|
}
|
|
b := make([]byte, 32)
|
|
rand.Read(b)
|
|
return base64.StdEncoding.EncodeToString(b)
|
|
}
|
|
|
|
// Returns the Unix epoch seconds.
|
|
func (t *Transport) epoch() int64 {
|
|
if t.clock != nil {
|
|
return t.clock.Now().Unix()
|
|
}
|
|
return time.Now().Unix()
|
|
}
|
|
|
|
// authHeaderValue formats OAuth parameters according to
|
|
// RFC 5849 3.5.1.
|
|
func authHeaderValue(oauthParams map[string]string) string {
|
|
pairs := sortParameters(encodeParameters(oauthParams), `%s="%s"`)
|
|
return "OAuth " + strings.Join(pairs, ", ")
|
|
}
|
|
|
|
// collectParameters returns a new map of request parameter keys and values,
// as defined in RFC 5849 3.4.1.3, built from the request URL's query string
// and the given OAuth protocol parameters.
//
// Only the first value of a repeated query key is kept, and an OAuth
// parameter replaces a query parameter with the same name. Neither input is
// modified.
func collectParameters(r *http.Request, oauthParams map[string]string) map[string]string {
	query := r.URL.Query()
	merged := make(map[string]string, len(query)+len(oauthParams))
	for name, values := range query {
		merged[name] = values[0]
	}
	// Applied last so protocol parameters take precedence over the query.
	for name, value := range oauthParams {
		merged[name] = value
	}
	return merged
}
|