add oauth1 and oauth2 transports with context sources
parent
7ef3597551
commit
709496f6e5
@ -0,0 +1,34 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package scm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
// Token represents the credentials used to authorize
|
||||
// the requests to access protected resources.
|
||||
Token struct {
|
||||
Token string
|
||||
Refresh string
|
||||
Expires time.Time
|
||||
}
|
||||
|
||||
// TokenSource returns a token.
|
||||
TokenSource interface {
|
||||
Token(context.Context) (*Token, error)
|
||||
}
|
||||
|
||||
// TokenKey is the key to use with the context.WithValue
|
||||
// function to associate an Token value with a context.
|
||||
TokenKey struct{}
|
||||
)
|
||||
|
||||
// WithContext returns a copy of parent in which the token value is set
|
||||
func WithContext(parent context.Context, token *Token) context.Context {
|
||||
return context.WithValue(parent, TokenKey{}, token)
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package internal
|
||||
|
||||
import "net/http"
|
||||
|
||||
// CloneRequest returns a clone of the provided
|
||||
// http.Request. The clone is a shallow copy of the struct
|
||||
// and its Header map.
|
||||
func CloneRequest(r *http.Request) *http.Request {
|
||||
// shallow copy of the struct
|
||||
r2 := new(http.Request)
|
||||
*r2 = *r
|
||||
// deep copy of the Header
|
||||
r2.Header = make(http.Header, len(r.Header))
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
return r2
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestCloneRequest(t *testing.T) {
|
||||
b := new(bytes.Buffer)
|
||||
r1, _ := http.NewRequest("GET", "http://example.com", b)
|
||||
r1.Header.Set("Accept", "application/json")
|
||||
r1.Header.Set("Etag", "1")
|
||||
r2 := CloneRequest(r1)
|
||||
if r1 == r2 {
|
||||
t.Errorf("Expect http.Request cloned")
|
||||
}
|
||||
if diff := cmp.Diff(r1.Header, r2.Header); diff != "" {
|
||||
t.Errorf("Expect http.Header cloned")
|
||||
t.Log(diff)
|
||||
}
|
||||
}
|
@ -0,0 +1,60 @@
|
||||
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
|
||||
// Copyrights licensed under the MIT License.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// encodeParameterString encodes collected OAuth parameters
|
||||
// into a parameter string as defined in RFC 5894 3.4.1.3.2.
|
||||
func encodeParameterString(params map[string]string) string {
|
||||
return strings.Join(sortParameters(
|
||||
encodeParameters(params), "%s=%s"), "&")
|
||||
}
|
||||
|
||||
// encodeParameters percent encodes parameter keys and
|
||||
// values according to RFC5849 3.6 and RFC3986 2.1 and
|
||||
// returns a new map.
|
||||
func encodeParameters(params map[string]string) map[string]string {
|
||||
encoded := map[string]string{}
|
||||
for key, value := range params {
|
||||
encoded[percentEncode(key)] = percentEncode(value)
|
||||
}
|
||||
return encoded
|
||||
}
|
||||
|
||||
// percentEncode percent encodes a string according to
|
||||
// RFC 3986 2.1.
|
||||
func percentEncode(input string) string {
|
||||
var buf bytes.Buffer
|
||||
for _, b := range []byte(input) {
|
||||
// if in unreserved set
|
||||
if shouldEscape(b) {
|
||||
buf.Write([]byte(fmt.Sprintf("%%%02X", b)))
|
||||
} else {
|
||||
// do not escape, write byte as-is
|
||||
buf.WriteByte(b)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// shouldEscape returns false if the byte is an unreserved
|
||||
// character that should not be escaped and true otherwise,
|
||||
// according to RFC 3986 2.1.
|
||||
func shouldEscape(c byte) bool {
|
||||
// RFC3986 2.3 unreserved characters
|
||||
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
|
||||
return false
|
||||
}
|
||||
switch c {
|
||||
case '-', '.', '_', '~':
|
||||
return false
|
||||
}
|
||||
// all other bytes must be escaped
|
||||
return true
|
||||
}
|
@ -0,0 +1,60 @@
|
||||
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
|
||||
// Copyrights licensed under the MIT License.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestEncodeParameterString(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"key 1": "key 2",
|
||||
"key+3": "key+4",
|
||||
}
|
||||
want := "key%201=key%202&key%2B3=key%2B4"
|
||||
got := encodeParameterString(params)
|
||||
if got != want {
|
||||
t.Errorf("Want encoded string %s, got %s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeParameters(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"key 1": "key 2",
|
||||
"key+3": "key+4",
|
||||
}
|
||||
want := map[string]string{
|
||||
"key%201": "key%202",
|
||||
"key%2B3": "key%2B4",
|
||||
}
|
||||
got := encodeParameters(params)
|
||||
if diff := cmp.Diff(got, want); diff != "" {
|
||||
t.Errorf("Unexpected Results")
|
||||
t.Log(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPercentEncode(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{" ", "%20"},
|
||||
{"%", "%25"},
|
||||
{"&", "%26"},
|
||||
{"-._", "-._"},
|
||||
{" /=+", "%20%2F%3D%2B"},
|
||||
{"Ladies + Gentlemen", "Ladies%20%2B%20Gentlemen"},
|
||||
{"An encoded string!", "An%20encoded%20string%21"},
|
||||
{"Dogs, Cats & Mice", "Dogs%2C%20Cats%20%26%20Mice"},
|
||||
{"☃", "%E2%98%83"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if output := percentEncode(c.input); output != c.expected {
|
||||
t.Errorf("expected %s, got %s", c.expected, output)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/drone/go-scm/scm"
|
||||
"github.com/drone/go-scm/scm/transport/internal"
|
||||
)
|
||||
|
||||
// clock provides a interface for current time providers. A Clock can be used
|
||||
// in place of calling time.Now() directly.
|
||||
type clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// A noncer provides random nonce strings.
|
||||
type noncer interface {
|
||||
Nonce() string
|
||||
}
|
||||
|
||||
// Transport is an http.RoundTripper that refreshes oauth
|
||||
// tokens, wrapping a base RoundTripper and refreshing the
|
||||
// token if expired.
|
||||
type Transport struct {
|
||||
// Consumer Key
|
||||
ConsumerKey string
|
||||
|
||||
// Consumer Private Key
|
||||
PrivateKey *rsa.PrivateKey
|
||||
|
||||
// Source supplies the Token to add to the request
|
||||
// Authorization headers.
|
||||
Source scm.TokenSource
|
||||
|
||||
// Base is the base RoundTripper used to make requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Base http.RoundTripper
|
||||
|
||||
noncer noncer
|
||||
clock clock
|
||||
}
|
||||
|
||||
// RoundTrip authorizes and authenticates the request with
|
||||
// an access token from the request context.
|
||||
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
ctx := r.Context()
|
||||
token, err := t.Source.Token(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token == nil {
|
||||
return t.base().RoundTrip(r)
|
||||
}
|
||||
r2 := internal.CloneRequest(r)
|
||||
err = t.setRequestAuthHeader(r2, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.base().RoundTrip(r2)
|
||||
}
|
||||
|
||||
// base returns the base transport. If no base transport
|
||||
// is configured, the default transport is returned.
|
||||
func (t *Transport) base() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
// setRequestAuthHeader sets the OAuth1 header for making
|
||||
// authenticated requests with an AccessToken according to
|
||||
// RFC 5849 3.1.
|
||||
func (t *Transport) setRequestAuthHeader(r *http.Request, token *scm.Token) error {
|
||||
oauthParams := t.commonOAuthParams()
|
||||
oauthParams["oauth_token"] = token.Token
|
||||
params := collectParameters(r, oauthParams)
|
||||
|
||||
signatureBase := signatureBase(r, params)
|
||||
signature, err := sign(t.PrivateKey, signatureBase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oauthParams["oauth_signature"] = signature
|
||||
r.Header.Set("Authorization", authHeaderValue(oauthParams))
|
||||
return nil
|
||||
}
|
||||
|
||||
// commonOAuthParams returns a map of the common OAuth1
|
||||
// protocol parameters, excluding the oauth_signature.
|
||||
func (t *Transport) commonOAuthParams() map[string]string {
|
||||
return map[string]string{
|
||||
"oauth_consumer_key": t.ConsumerKey,
|
||||
"oauth_signature_method": "RSA-SHA1",
|
||||
"oauth_timestamp": strconv.FormatInt(t.epoch(), 10),
|
||||
"oauth_nonce": t.nonce(),
|
||||
"oauth_version": "1.0",
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a base64 encoded random 32 byte string.
|
||||
func (t *Transport) nonce() string {
|
||||
if t.noncer != nil {
|
||||
return t.noncer.Nonce()
|
||||
}
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// Returns the Unix epoch seconds.
|
||||
func (t *Transport) epoch() int64 {
|
||||
if t.clock != nil {
|
||||
return t.clock.Now().Unix()
|
||||
}
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
// authHeaderValue formats OAuth parameters according to
|
||||
// RFC 5849 3.5.1.
|
||||
func authHeaderValue(oauthParams map[string]string) string {
|
||||
pairs := sortParameters(encodeParameters(oauthParams), `%s="%s"`)
|
||||
return "OAuth " + strings.Join(pairs, ", ")
|
||||
}
|
||||
|
||||
// collectParameters returns a map of request parameter keys
|
||||
// and values as defined in RFC 5849 3.4.1.3.
|
||||
func collectParameters(r *http.Request, oauthParams map[string]string) map[string]string {
|
||||
params := map[string]string{}
|
||||
for key, value := range r.URL.Query() {
|
||||
params[key] = value[0]
|
||||
}
|
||||
for key, value := range oauthParams {
|
||||
params[key] = value
|
||||
}
|
||||
return params
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth1
|
@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
|
||||
// Copyrights licensed under the MIT License.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// signatureBase returns the OAuth1 signature base string
|
||||
// according to RFC5849 3.4.1.
|
||||
func signatureBase(req *http.Request, params map[string]string) string {
|
||||
method := strings.ToUpper(req.Method)
|
||||
baseURL := baseURI(req)
|
||||
parameterString := encodeParameterString(params)
|
||||
baseParts := []string{method,
|
||||
percentEncode(baseURL),
|
||||
percentEncode(parameterString)}
|
||||
return strings.Join(baseParts, "&")
|
||||
}
|
||||
|
||||
// sign calculates the signature of the message SHA1 digests
|
||||
// using the given RSA private key.
|
||||
func sign(privateKey *rsa.PrivateKey, message string) (string, error) {
|
||||
digest := sha1.Sum([]byte(message))
|
||||
signature, err := rsa.SignPKCS1v15(
|
||||
rand.Reader, privateKey, crypto.SHA1, digest[:])
|
||||
return base64.StdEncoding.EncodeToString(signature), err
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth1
|
@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
|
||||
// Copyrights licensed under the MIT License.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// sortParameters sorts parameters by key and returns a
|
||||
// slice of formatted key value pairs.
|
||||
func sortParameters(params map[string]string, format string) []string {
|
||||
// sort by key
|
||||
keys := make([]string, len(params))
|
||||
i := 0
|
||||
for key := range params {
|
||||
keys[i] = key
|
||||
i++
|
||||
}
|
||||
sort.Strings(keys)
|
||||
// parameter join
|
||||
pairs := make([]string, len(params))
|
||||
for i, key := range keys {
|
||||
pairs[i] = fmt.Sprintf(format, key, params[key])
|
||||
}
|
||||
return pairs
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestSortParameters(t *testing.T) {
|
||||
params := map[string]string{
|
||||
"page": "1",
|
||||
"per_page": "25",
|
||||
"oauth_version": "1.0",
|
||||
"oauth_signature_method": "RSA-SHA1",
|
||||
"oauth_consumer_key": "12345",
|
||||
}
|
||||
want := []string{
|
||||
"oauth_consumer_key=12345",
|
||||
"oauth_signature_method=RSA-SHA1",
|
||||
"oauth_version=1.0",
|
||||
"page=1",
|
||||
"per_page=25",
|
||||
}
|
||||
got := sortParameters(params, "%s=%s")
|
||||
if diff := cmp.Diff(got, want); diff != "" {
|
||||
t.Errorf("Unexpected Results")
|
||||
t.Log(diff)
|
||||
}
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/drone/go-scm/scm"
|
||||
)
|
||||
|
||||
// StaticTokenSource returns a TokenSource that always
|
||||
// returns the same token. Because the provided token t
|
||||
// is never refreshed, StaticTokenSource is only useful
|
||||
// for tokens that never expire.
|
||||
func StaticTokenSource(t *scm.Token) scm.TokenSource {
|
||||
return staticTokenSource{t}
|
||||
}
|
||||
|
||||
type staticTokenSource struct {
|
||||
token *scm.Token
|
||||
}
|
||||
|
||||
func (s staticTokenSource) Token(context.Context) (*scm.Token, error) {
|
||||
return s.token, nil
|
||||
}
|
||||
|
||||
// ContextTokenSource returns a TokenSource that returns
|
||||
// a token from the http.Request context.
|
||||
func ContextTokenSource() scm.TokenSource {
|
||||
return contextTokenSource{}
|
||||
}
|
||||
|
||||
type contextTokenSource struct {
|
||||
}
|
||||
|
||||
func (s contextTokenSource) Token(ctx context.Context) (*scm.Token, error) {
|
||||
token, _ := ctx.Value(scm.TokenKey{}).(*scm.Token)
|
||||
return token, nil
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/drone/go-scm/scm"
|
||||
)
|
||||
|
||||
func TestContextTokenSource(t *testing.T) {
|
||||
source := ContextTokenSource()
|
||||
want := new(scm.Token)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, scm.TokenKey{}, want)
|
||||
got, err := source.Token(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("Expect token retrieved from Context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextTokenSource_Nil(t *testing.T) {
|
||||
source := ContextTokenSource()
|
||||
|
||||
ctx := context.Background()
|
||||
token, err := source.Token(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if token != nil {
|
||||
t.Errorf("Expect nil token from Context")
|
||||
}
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
|
||||
// Copyrights licensed under the MIT License.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// baseURI returns the base string URI of a request
|
||||
// according to RFC 5849 3.4.1.2.
|
||||
func baseURI(r *http.Request) string {
|
||||
scheme := strings.ToLower(r.URL.Scheme)
|
||||
host := strings.ToLower(r.URL.Host)
|
||||
if hostPort := strings.Split(host, ":"); len(hostPort) == 2 && (hostPort[1] == "80" || hostPort[1] == "443") {
|
||||
host = hostPort[0]
|
||||
}
|
||||
path := r.URL.EscapedPath()
|
||||
return fmt.Sprintf("%v://%v%v", scheme, host, path)
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth1
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
before string
|
||||
after string
|
||||
}{
|
||||
{
|
||||
before: "HTTP://EXAMPLE.COM:80/r%20v/X?id=123",
|
||||
after: "http://example.com/r%20v/X",
|
||||
},
|
||||
{
|
||||
before: "http://example.com:80",
|
||||
after: "http://example.com",
|
||||
},
|
||||
{
|
||||
before: "https://example.com:443",
|
||||
after: "https://example.com",
|
||||
},
|
||||
{
|
||||
before: "http://www.example.com:8080/?q=1",
|
||||
after: "http://www.example.com:8080/",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
r := new(http.Request)
|
||||
r.URL, _ = url.Parse(test.before)
|
||||
if got, want := baseURI(r), test.after; got != want {
|
||||
t.Errorf("Want url %s, got %s", want, got)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/drone/go-scm/scm"
|
||||
"github.com/drone/go-scm/scm/transport/internal"
|
||||
)
|
||||
|
||||
// Transport is an http.RoundTripper that refreshes oauth
|
||||
// tokens, wrapping a base RoundTripper and refreshing the
|
||||
// token if expired.
|
||||
type Transport struct {
|
||||
Source scm.TokenSource
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip authorizes and authenticates the request with
|
||||
// an access token from the request context.
|
||||
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
ctx := r.Context()
|
||||
token, err := t.Source.Token(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token == nil {
|
||||
return t.base().RoundTrip(r)
|
||||
}
|
||||
r2 := internal.CloneRequest(r)
|
||||
r2.Header.Set("Authorization", "Bearer "+token.Token)
|
||||
return t.base().RoundTrip(r2)
|
||||
}
|
||||
|
||||
// base returns the base transport. If no base transport
|
||||
// is configured, the default transport is returned.
|
||||
func (t *Transport) base() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
@ -0,0 +1,85 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/drone/go-scm/scm"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
)
|
||||
|
||||
func TestTransport(t *testing.T) {
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://api.github.com").
|
||||
Get("/user").
|
||||
MatchHeader("Authorization", "Bearer mF_9.B5f-4.1JqM").
|
||||
Reply(200)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &Transport{
|
||||
Source: StaticTokenSource(
|
||||
&scm.Token{
|
||||
Token: "mF_9.B5f-4.1JqM",
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
res, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
}
|
||||
|
||||
func TestTransport_NoToken(t *testing.T) {
|
||||
defer gock.Off()
|
||||
|
||||
gock.New("https://api.github.com").
|
||||
Get("/user").
|
||||
Reply(200)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &Transport{
|
||||
Source: ContextTokenSource(),
|
||||
},
|
||||
}
|
||||
|
||||
res, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
}
|
||||
|
||||
func TestTransport_TokenError(t *testing.T) {
|
||||
want := errors.New("Cannot retrieve token")
|
||||
client := &http.Client{
|
||||
Transport: &Transport{
|
||||
Source: mockErrorSource{want},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := client.Get("https://api.github.com/user")
|
||||
if err == nil {
|
||||
t.Errorf("Expect token source error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
type mockErrorSource struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (s mockErrorSource) Token(ctx context.Context) (*scm.Token, error) {
|
||||
return nil, s.err
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/drone/go-scm/scm"
|
||||
)
|
||||
|
||||
// StaticTokenSource returns a TokenSource that always
|
||||
// returns the same token. Because the provided token t
|
||||
// is never refreshed, StaticTokenSource is only useful
|
||||
// for tokens that never expire.
|
||||
func StaticTokenSource(t *scm.Token) scm.TokenSource {
|
||||
return staticTokenSource{t}
|
||||
}
|
||||
|
||||
type staticTokenSource struct {
|
||||
token *scm.Token
|
||||
}
|
||||
|
||||
func (s staticTokenSource) Token(context.Context) (*scm.Token, error) {
|
||||
return s.token, nil
|
||||
}
|
||||
|
||||
// ContextTokenSource returns a TokenSource that returns
|
||||
// a token from the http.Request context.
|
||||
func ContextTokenSource() scm.TokenSource {
|
||||
return contextTokenSource{}
|
||||
}
|
||||
|
||||
type contextTokenSource struct {
|
||||
}
|
||||
|
||||
func (s contextTokenSource) Token(ctx context.Context) (*scm.Token, error) {
|
||||
token, _ := ctx.Value(scm.TokenKey{}).(*scm.Token)
|
||||
return token, nil
|
||||
}
|
@ -0,0 +1,42 @@
|
||||
// Copyright 2018 Drone.IO Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/drone/go-scm/scm"
|
||||
)
|
||||
|
||||
func TestContextTokenSource(t *testing.T) {
|
||||
source := ContextTokenSource()
|
||||
want := new(scm.Token)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, scm.TokenKey{}, want)
|
||||
got, err := source.Token(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("Expect token retrieved from Context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextTokenSource_Nil(t *testing.T) {
|
||||
source := ContextTokenSource()
|
||||
|
||||
ctx := context.Background()
|
||||
token, err := source.Token(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if token != nil {
|
||||
t.Errorf("Expect nil token from Context")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue