add oauth1 and oauth2 transports with context sources

feature/refresh
Brad Rydzewski 6 years ago
parent 7ef3597551
commit 709496f6e5

@ -0,0 +1,34 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package scm
import (
"context"
"time"
)
type (
// Token represents the credentials used to authorize
// the requests to access protected resources.
Token struct {
Token string
Refresh string
Expires time.Time
}
// TokenSource returns a token.
TokenSource interface {
Token(context.Context) (*Token, error)
}
// TokenKey is the key to use with the context.WithValue
// function to associate an Token value with a context.
TokenKey struct{}
)
// WithContext returns a copy of parent in which the token value is set
func WithContext(parent context.Context, token *Token) context.Context {
return context.WithValue(parent, TokenKey{}, token)
}

@ -0,0 +1,22 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package internal
import "net/http"
// CloneRequest returns a clone of the provided
// http.Request. The clone is a shallow copy of the struct
// and its Header map.
func CloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}

@ -0,0 +1,28 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package internal
import (
"bytes"
"net/http"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestCloneRequest(t *testing.T) {
b := new(bytes.Buffer)
r1, _ := http.NewRequest("GET", "http://example.com", b)
r1.Header.Set("Accept", "application/json")
r1.Header.Set("Etag", "1")
r2 := CloneRequest(r1)
if r1 == r2 {
t.Errorf("Expect http.Request cloned")
}
if diff := cmp.Diff(r1.Header, r2.Header); diff != "" {
t.Errorf("Expect http.Header cloned")
t.Log(diff)
}
}

@ -0,0 +1,60 @@
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
// Copyrights licensed under the MIT License.
package oauth1
import (
"bytes"
"fmt"
"strings"
)
// encodeParameterString encodes collected OAuth parameters
// into a parameter string as defined in RFC 5894 3.4.1.3.2.
func encodeParameterString(params map[string]string) string {
return strings.Join(sortParameters(
encodeParameters(params), "%s=%s"), "&")
}
// encodeParameters percent encodes parameter keys and
// values according to RFC5849 3.6 and RFC3986 2.1 and
// returns a new map.
func encodeParameters(params map[string]string) map[string]string {
encoded := map[string]string{}
for key, value := range params {
encoded[percentEncode(key)] = percentEncode(value)
}
return encoded
}
// percentEncode percent encodes a string according to
// RFC 3986 2.1.
func percentEncode(input string) string {
var buf bytes.Buffer
for _, b := range []byte(input) {
// if in unreserved set
if shouldEscape(b) {
buf.Write([]byte(fmt.Sprintf("%%%02X", b)))
} else {
// do not escape, write byte as-is
buf.WriteByte(b)
}
}
return buf.String()
}
// shouldEscape returns false if the byte is an unreserved
// character that should not be escaped and true otherwise,
// according to RFC 3986 2.1.
func shouldEscape(c byte) bool {
// RFC3986 2.3 unreserved characters
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
return false
}
switch c {
case '-', '.', '_', '~':
return false
}
// all other bytes must be escaped
return true
}

@ -0,0 +1,60 @@
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
// Copyrights licensed under the MIT License.
package oauth1
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func TestEncodeParameterString(t *testing.T) {
params := map[string]string{
"key 1": "key 2",
"key+3": "key+4",
}
want := "key%201=key%202&key%2B3=key%2B4"
got := encodeParameterString(params)
if got != want {
t.Errorf("Want encoded string %s, got %s", want, got)
}
}
func TestEncodeParameters(t *testing.T) {
params := map[string]string{
"key 1": "key 2",
"key+3": "key+4",
}
want := map[string]string{
"key%201": "key%202",
"key%2B3": "key%2B4",
}
got := encodeParameters(params)
if diff := cmp.Diff(got, want); diff != "" {
t.Errorf("Unexpected Results")
t.Log(diff)
}
}
func TestPercentEncode(t *testing.T) {
cases := []struct {
input string
expected string
}{
{" ", "%20"},
{"%", "%25"},
{"&", "%26"},
{"-._", "-._"},
{" /=+", "%20%2F%3D%2B"},
{"Ladies + Gentlemen", "Ladies%20%2B%20Gentlemen"},
{"An encoded string!", "An%20encoded%20string%21"},
{"Dogs, Cats & Mice", "Dogs%2C%20Cats%20%26%20Mice"},
{"☃", "%E2%98%83"},
}
for _, c := range cases {
if output := percentEncode(c.input); output != c.expected {
t.Errorf("expected %s, got %s", c.expected, output)
}
}
}

@ -0,0 +1,147 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"net/http"
"strconv"
"strings"
"time"
"github.com/drone/go-scm/scm"
"github.com/drone/go-scm/scm/transport/internal"
)
// clock provides a interface for current time providers. A Clock can be used
// in place of calling time.Now() directly.
type clock interface {
Now() time.Time
}
// A noncer provides random nonce strings.
type noncer interface {
Nonce() string
}
// Transport is an http.RoundTripper that refreshes oauth
// tokens, wrapping a base RoundTripper and refreshing the
// token if expired.
type Transport struct {
// Consumer Key
ConsumerKey string
// Consumer Private Key
PrivateKey *rsa.PrivateKey
// Source supplies the Token to add to the request
// Authorization headers.
Source scm.TokenSource
// Base is the base RoundTripper used to make requests.
// If nil, http.DefaultTransport is used.
Base http.RoundTripper
noncer noncer
clock clock
}
// RoundTrip authorizes and authenticates the request with
// an access token from the request context.
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
ctx := r.Context()
token, err := t.Source.Token(ctx)
if err != nil {
return nil, err
}
if token == nil {
return t.base().RoundTrip(r)
}
r2 := internal.CloneRequest(r)
err = t.setRequestAuthHeader(r2, token)
if err != nil {
return nil, err
}
return t.base().RoundTrip(r2)
}
// base returns the base transport. If no base transport
// is configured, the default transport is returned.
func (t *Transport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
// setRequestAuthHeader sets the OAuth1 header for making
// authenticated requests with an AccessToken according to
// RFC 5849 3.1.
func (t *Transport) setRequestAuthHeader(r *http.Request, token *scm.Token) error {
oauthParams := t.commonOAuthParams()
oauthParams["oauth_token"] = token.Token
params := collectParameters(r, oauthParams)
signatureBase := signatureBase(r, params)
signature, err := sign(t.PrivateKey, signatureBase)
if err != nil {
return err
}
oauthParams["oauth_signature"] = signature
r.Header.Set("Authorization", authHeaderValue(oauthParams))
return nil
}
// commonOAuthParams returns a map of the common OAuth1
// protocol parameters, excluding the oauth_signature.
func (t *Transport) commonOAuthParams() map[string]string {
return map[string]string{
"oauth_consumer_key": t.ConsumerKey,
"oauth_signature_method": "RSA-SHA1",
"oauth_timestamp": strconv.FormatInt(t.epoch(), 10),
"oauth_nonce": t.nonce(),
"oauth_version": "1.0",
}
}
// Returns a base64 encoded random 32 byte string.
func (t *Transport) nonce() string {
if t.noncer != nil {
return t.noncer.Nonce()
}
b := make([]byte, 32)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
// Returns the Unix epoch seconds.
func (t *Transport) epoch() int64 {
if t.clock != nil {
return t.clock.Now().Unix()
}
return time.Now().Unix()
}
// authHeaderValue formats OAuth parameters according to
// RFC 5849 3.5.1.
func authHeaderValue(oauthParams map[string]string) string {
pairs := sortParameters(encodeParameters(oauthParams), `%s="%s"`)
return "OAuth " + strings.Join(pairs, ", ")
}
// collectParameters returns a map of request parameter keys
// and values as defined in RFC 5849 3.4.1.3.
func collectParameters(r *http.Request, oauthParams map[string]string) map[string]string {
params := map[string]string{}
for key, value := range r.URL.Query() {
params[key] = value[0]
}
for key, value := range oauthParams {
params[key] = value
}
return params
}

@ -0,0 +1,5 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1

@ -0,0 +1,35 @@
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
// Copyrights licensed under the MIT License.
package oauth1
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"encoding/base64"
"net/http"
"strings"
)
// signatureBase returns the OAuth1 signature base string
// according to RFC5849 3.4.1.
func signatureBase(req *http.Request, params map[string]string) string {
method := strings.ToUpper(req.Method)
baseURL := baseURI(req)
parameterString := encodeParameterString(params)
baseParts := []string{method,
percentEncode(baseURL),
percentEncode(parameterString)}
return strings.Join(baseParts, "&")
}
// sign calculates the signature of the message SHA1 digests
// using the given RSA private key.
func sign(privateKey *rsa.PrivateKey, message string) (string, error) {
digest := sha1.Sum([]byte(message))
signature, err := rsa.SignPKCS1v15(
rand.Reader, privateKey, crypto.SHA1, digest[:])
return base64.StdEncoding.EncodeToString(signature), err
}

@ -0,0 +1,5 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1

@ -0,0 +1,28 @@
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
// Copyrights licensed under the MIT License.
package oauth1
import (
"fmt"
"sort"
)
// sortParameters sorts parameters by key and returns a
// slice of formatted key value pairs.
func sortParameters(params map[string]string, format string) []string {
// sort by key
keys := make([]string, len(params))
i := 0
for key := range params {
keys[i] = key
i++
}
sort.Strings(keys)
// parameter join
pairs := make([]string, len(params))
for i, key := range keys {
pairs[i] = fmt.Sprintf(format, key, params[key])
}
return pairs
}

@ -0,0 +1,33 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func TestSortParameters(t *testing.T) {
params := map[string]string{
"page": "1",
"per_page": "25",
"oauth_version": "1.0",
"oauth_signature_method": "RSA-SHA1",
"oauth_consumer_key": "12345",
}
want := []string{
"oauth_consumer_key=12345",
"oauth_signature_method=RSA-SHA1",
"oauth_version=1.0",
"page=1",
"per_page=25",
}
got := sortParameters(params, "%s=%s")
if diff := cmp.Diff(got, want); diff != "" {
t.Errorf("Unexpected Results")
t.Log(diff)
}
}

@ -0,0 +1,41 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1
import (
"context"
"github.com/drone/go-scm/scm"
)
// StaticTokenSource returns a TokenSource that always
// returns the same token. Because the provided token t
// is never refreshed, StaticTokenSource is only useful
// for tokens that never expire.
func StaticTokenSource(t *scm.Token) scm.TokenSource {
return staticTokenSource{t}
}
type staticTokenSource struct {
token *scm.Token
}
func (s staticTokenSource) Token(context.Context) (*scm.Token, error) {
return s.token, nil
}
// ContextTokenSource returns a TokenSource that returns
// a token from the http.Request context.
func ContextTokenSource() scm.TokenSource {
return contextTokenSource{}
}
type contextTokenSource struct {
}
func (s contextTokenSource) Token(ctx context.Context) (*scm.Token, error) {
token, _ := ctx.Value(scm.TokenKey{}).(*scm.Token)
return token, nil
}

@ -0,0 +1,42 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1
import (
"context"
"testing"
"github.com/drone/go-scm/scm"
)
func TestContextTokenSource(t *testing.T) {
source := ContextTokenSource()
want := new(scm.Token)
ctx := context.Background()
ctx = context.WithValue(ctx, scm.TokenKey{}, want)
got, err := source.Token(ctx)
if err != nil {
t.Error(err)
return
}
if got != want {
t.Errorf("Expect token retrieved from Context")
}
}
func TestContextTokenSource_Nil(t *testing.T) {
source := ContextTokenSource()
ctx := context.Background()
token, err := source.Token(ctx)
if err != nil {
t.Error(err)
return
}
if token != nil {
t.Errorf("Expect nil token from Context")
}
}

@ -0,0 +1,22 @@
// Copyright (c) 2015 Dalton Hubble. All rights reserved.
// Copyrights licensed under the MIT License.
package oauth1
import (
"fmt"
"net/http"
"strings"
)
// baseURI returns the base string URI of a request
// according to RFC 5849 3.4.1.2.
func baseURI(r *http.Request) string {
scheme := strings.ToLower(r.URL.Scheme)
host := strings.ToLower(r.URL.Host)
if hostPort := strings.Split(host, ":"); len(hostPort) == 2 && (hostPort[1] == "80" || hostPort[1] == "443") {
host = hostPort[0]
}
path := r.URL.EscapedPath()
return fmt.Sprintf("%v://%v%v", scheme, host, path)
}

@ -0,0 +1,42 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1
import (
"net/http"
"net/url"
"testing"
)
func TestBaseURL(t *testing.T) {
tests := []struct {
before string
after string
}{
{
before: "HTTP://EXAMPLE.COM:80/r%20v/X?id=123",
after: "http://example.com/r%20v/X",
},
{
before: "http://example.com:80",
after: "http://example.com",
},
{
before: "https://example.com:443",
after: "https://example.com",
},
{
before: "http://www.example.com:8080/?q=1",
after: "http://www.example.com:8080/",
},
}
for _, test := range tests {
r := new(http.Request)
r.URL, _ = url.Parse(test.before)
if got, want := baseURI(r), test.after; got != want {
t.Errorf("Want url %s, got %s", want, got)
}
}
}

@ -0,0 +1,45 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"net/http"
"github.com/drone/go-scm/scm"
"github.com/drone/go-scm/scm/transport/internal"
)
// Transport is an http.RoundTripper that refreshes oauth
// tokens, wrapping a base RoundTripper and refreshing the
// token if expired.
type Transport struct {
Source scm.TokenSource
Base http.RoundTripper
}
// RoundTrip authorizes and authenticates the request with
// an access token from the request context.
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
ctx := r.Context()
token, err := t.Source.Token(ctx)
if err != nil {
return nil, err
}
if token == nil {
return t.base().RoundTrip(r)
}
r2 := internal.CloneRequest(r)
r2.Header.Set("Authorization", "Bearer "+token.Token)
return t.base().RoundTrip(r2)
}
// base returns the base transport. If no base transport
// is configured, the default transport is returned.
func (t *Transport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}

@ -0,0 +1,85 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"context"
"errors"
"net/http"
"testing"
"github.com/drone/go-scm/scm"
"github.com/h2non/gock"
)
func TestTransport(t *testing.T) {
defer gock.Off()
gock.New("https://api.github.com").
Get("/user").
MatchHeader("Authorization", "Bearer mF_9.B5f-4.1JqM").
Reply(200)
client := &http.Client{
Transport: &Transport{
Source: StaticTokenSource(
&scm.Token{
Token: "mF_9.B5f-4.1JqM",
},
),
},
}
res, err := client.Get("https://api.github.com/user")
if err != nil {
t.Error(err)
return
}
defer res.Body.Close()
}
func TestTransport_NoToken(t *testing.T) {
defer gock.Off()
gock.New("https://api.github.com").
Get("/user").
Reply(200)
client := &http.Client{
Transport: &Transport{
Source: ContextTokenSource(),
},
}
res, err := client.Get("https://api.github.com/user")
if err != nil {
t.Error(err)
return
}
defer res.Body.Close()
}
func TestTransport_TokenError(t *testing.T) {
want := errors.New("Cannot retrieve token")
client := &http.Client{
Transport: &Transport{
Source: mockErrorSource{want},
},
}
_, err := client.Get("https://api.github.com/user")
if err == nil {
t.Errorf("Expect token source error, got nil")
}
}
type mockErrorSource struct {
err error
}
func (s mockErrorSource) Token(ctx context.Context) (*scm.Token, error) {
return nil, s.err
}

@ -0,0 +1,41 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"context"
"github.com/drone/go-scm/scm"
)
// StaticTokenSource returns a TokenSource that always
// returns the same token. Because the provided token t
// is never refreshed, StaticTokenSource is only useful
// for tokens that never expire.
func StaticTokenSource(t *scm.Token) scm.TokenSource {
return staticTokenSource{t}
}
type staticTokenSource struct {
token *scm.Token
}
func (s staticTokenSource) Token(context.Context) (*scm.Token, error) {
return s.token, nil
}
// ContextTokenSource returns a TokenSource that returns
// a token from the http.Request context.
func ContextTokenSource() scm.TokenSource {
return contextTokenSource{}
}
type contextTokenSource struct {
}
func (s contextTokenSource) Token(ctx context.Context) (*scm.Token, error) {
token, _ := ctx.Value(scm.TokenKey{}).(*scm.Token)
return token, nil
}

@ -0,0 +1,42 @@
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"context"
"testing"
"github.com/drone/go-scm/scm"
)
func TestContextTokenSource(t *testing.T) {
source := ContextTokenSource()
want := new(scm.Token)
ctx := context.Background()
ctx = context.WithValue(ctx, scm.TokenKey{}, want)
got, err := source.Token(ctx)
if err != nil {
t.Error(err)
return
}
if got != want {
t.Errorf("Expect token retrieved from Context")
}
}
func TestContextTokenSource_Nil(t *testing.T) {
source := ContextTokenSource()
ctx := context.Background()
token, err := source.Token(ctx)
if err != nil {
t.Error(err)
return
}
if token != nil {
t.Errorf("Expect nil token from Context")
}
}
Loading…
Cancel
Save