diff options
author | Mike Crute <mike@crute.us> | 2017-09-19 04:39:36 +0000 |
---|---|---|
committer | Mike Crute <mike@crute.us> | 2017-09-19 04:39:36 +0000 |
commit | 9f7861ffe1397da514606b189f5b3e383f4e7ed7 (patch) | |
tree | 2bd145745efba52ac136166e4f4535cfd59359ea | |
parent | b7867d9cf5b0dd175b8167a552b830ebfe47d0ed (diff) | |
download | oidc_proxy-9f7861ffe1397da514606b189f5b3e383f4e7ed7.tar.bz2 oidc_proxy-9f7861ffe1397da514606b189f5b3e383f4e7ed7.tar.xz oidc_proxy-9f7861ffe1397da514606b189f5b3e383f4e7ed7.zip |
Finish out most of the proxy functionality
-rw-r--r-- | cautious_http_client.go | 87 | ||||
-rw-r--r-- | jwks_fetcher.go | 118 | ||||
-rw-r--r-- | jws_validator.go | 110 | ||||
-rw-r--r-- | key_validator.go | 36 | ||||
-rw-r--r-- | main.go | 409 | ||||
-rwxr-xr-x | oidc_proxy | bin | 6495319 -> 7509902 bytes | |||
-rw-r--r-- | util.go | 18 |
7 files changed, 534 insertions, 244 deletions
diff --git a/cautious_http_client.go b/cautious_http_client.go index 2f33ae0..34b736f 100644 --- a/cautious_http_client.go +++ b/cautious_http_client.go | |||
@@ -2,24 +2,29 @@ package main | |||
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "encoding/json" | 4 | "encoding/json" |
5 | "fmt" | 5 | "github.com/lox/httpcache" |
6 | "github.com/pkg/errors" | ||
6 | "net" | 7 | "net" |
7 | "net/http" | 8 | "net/http" |
8 | "net/url" | 9 | "net/url" |
10 | "strings" | ||
9 | "time" | 11 | "time" |
10 | ) | 12 | ) |
11 | 13 | ||
12 | type CautiousHTTPClient interface { | 14 | type CautiousHTTPClient interface { |
13 | Get(string) (*http.Response, error) | 15 | Get(string) (*http.Response, error) |
14 | GetJSON(string, interface{}) error | 16 | GetJSON(string, interface{}) error |
17 | GetJSONExpires(string, interface{}) (time.Duration, error) | ||
15 | } | 18 | } |
16 | 19 | ||
17 | type cautiousHttpClient struct { | 20 | type cautiousHttpClient struct { |
18 | client *http.Client | 21 | allowHttp bool |
22 | client *http.Client | ||
19 | } | 23 | } |
20 | 24 | ||
21 | func NewCautiousHTTPClient() CautiousHTTPClient { | 25 | // allowHttp is UNSAFE and technically validates the spec but it does make it |
22 | // May Need: TLSClientConfig *tls.Config | 26 | // easier to work in dev so leaving it in for now |
27 | func NewCautiousHTTPClient(allowHttp bool) (CautiousHTTPClient, error) { | ||
23 | CautiousTransport := &http.Transport{ | 28 | CautiousTransport := &http.Transport{ |
24 | Proxy: http.ProxyFromEnvironment, | 29 | Proxy: http.ProxyFromEnvironment, |
25 | DialContext: (&net.Dialer{ | 30 | DialContext: (&net.Dialer{ |
@@ -36,44 +41,100 @@ func NewCautiousHTTPClient() CautiousHTTPClient { | |||
36 | } | 41 | } |
37 | 42 | ||
38 | return &cautiousHttpClient{ | 43 | return &cautiousHttpClient{ |
44 | allowHttp: allowHttp, | ||
39 | client: &http.Client{ | 45 | client: &http.Client{ |
40 | Transport: CautiousTransport, | 46 | Transport: CautiousTransport, |
41 | Timeout: 30 * time.Second, | 47 | Timeout: 30 * time.Second, |
42 | }, | 48 | }, |
43 | } | 49 | }, nil |
44 | } | 50 | } |
45 | 51 | ||
46 | func (c *cautiousHttpClient) Get(gurl string) (*http.Response, error) { | 52 | func (c *cautiousHttpClient) Get(gurl string) (*http.Response, error) { |
47 | u, err := url.Parse(gurl) | 53 | u, err := url.Parse(gurl) |
48 | if err != nil { | 54 | if err != nil { |
49 | return nil, err | 55 | return nil, errors.WithStack(err) |
50 | } | 56 | } |
51 | 57 | ||
52 | // TODO | 58 | if u.Scheme != "https" && !c.allowHttp { |
53 | if u.Scheme != "https" && false { | 59 | return nil, errors.Errorf("URL for GET must be secure") |
54 | return nil, fmt.Errorf("URL for GET must be secure") | ||
55 | } | 60 | } |
56 | 61 | ||
57 | r, err := c.client.Get(u.String()) | 62 | r, err := c.client.Get(u.String()) |
58 | if err != nil { | 63 | if err != nil { |
59 | return nil, err | 64 | return nil, errors.WithStack(err) |
60 | } | 65 | } |
61 | r.Body = http.MaxBytesReader(nil, r.Body, 1000000) | 66 | r.Body = http.MaxBytesReader(nil, r.Body, 1000000) |
62 | return r, err | 67 | |
68 | return r, nil | ||
63 | } | 69 | } |
64 | 70 | ||
65 | func (c *cautiousHttpClient) GetJSON(url string, rv interface{}) error { | 71 | func (c *cautiousHttpClient) GetJSON(url string, rv interface{}) error { |
66 | r, err := c.Get(url) | 72 | r, err := c.Get(url) |
67 | if err != nil { | 73 | if err != nil { |
68 | return err | 74 | return errors.WithStack(err) |
69 | } | 75 | } |
70 | defer r.Body.Close() | 76 | defer r.Body.Close() |
71 | 77 | ||
72 | d := json.NewDecoder(r.Body) | 78 | d := json.NewDecoder(r.Body) |
73 | err = d.Decode(rv) | 79 | err = d.Decode(rv) |
74 | if err != nil { | 80 | if err != nil { |
75 | return err | 81 | return errors.WithStack(err) |
76 | } | 82 | } |
77 | 83 | ||
78 | return nil | 84 | return nil |
79 | } | 85 | } |
86 | |||
87 | func (c *cautiousHttpClient) GetJSONExpires(url string, rv interface{}) (time.Duration, error) { | ||
88 | r, err := c.Get(url) | ||
89 | if err != nil { | ||
90 | return time.Duration(0), errors.WithStack(err) | ||
91 | } | ||
92 | defer r.Body.Close() | ||
93 | |||
94 | res := httpcache.NewResource(r.StatusCode, nil, r.Header) | ||
95 | |||
96 | d := json.NewDecoder(r.Body) | ||
97 | err = d.Decode(rv) | ||
98 | if err != nil { | ||
99 | return time.Duration(0), errors.WithStack(err) | ||
100 | } | ||
101 | |||
102 | return refreshAfter(res), nil | ||
103 | } | ||
104 | |||
105 | type JSONURL struct { | ||
106 | *url.URL | ||
107 | } | ||
108 | |||
109 | func (u *JSONURL) AsURL() *url.URL { | ||
110 | return u.URL | ||
111 | } | ||
112 | |||
113 | func (u *JSONURL) UnmarshalJSON(data []byte) error { | ||
114 | d := strings.Trim(string(data), "\"") | ||
115 | pu, err := url.Parse(d) | ||
116 | if err != nil { | ||
117 | return errors.WithStack(err) | ||
118 | } | ||
119 | |||
120 | u.URL = pu | ||
121 | return nil | ||
122 | } | ||
123 | |||
124 | func refreshAfter(res *httpcache.Resource) time.Duration { | ||
125 | maxAge, err := res.MaxAge(false) | ||
126 | if err != nil { | ||
127 | return time.Duration(0) | ||
128 | } | ||
129 | |||
130 | age, err := res.Age() | ||
131 | if err != nil { | ||
132 | return time.Duration(0) | ||
133 | } | ||
134 | |||
135 | if hFresh := res.HeuristicFreshness(); hFresh > maxAge { | ||
136 | maxAge = hFresh | ||
137 | } | ||
138 | |||
139 | return maxAge - age | ||
140 | } | ||
diff --git a/jwks_fetcher.go b/jwks_fetcher.go new file mode 100644 index 0000000..9925430 --- /dev/null +++ b/jwks_fetcher.go | |||
@@ -0,0 +1,118 @@ | |||
1 | package main | ||
2 | |||
3 | import ( | ||
4 | "github.com/pkg/errors" | ||
5 | "gopkg.in/square/go-jose.v2" | ||
6 | "log" | ||
7 | "net/url" | ||
8 | "time" | ||
9 | ) | ||
10 | |||
11 | const ( | ||
12 | REQUEST_BUFFER_SIZE = 10 | ||
13 | KEY_MAP_INITIAL_SIZE = 5 | ||
14 | DEFAULT_REFRESH_INTERVAL = 15 * time.Minute | ||
15 | MIN_REFRESH_INTERVAL = 1 * time.Minute | ||
16 | ) | ||
17 | |||
18 | type KeyRequest struct { | ||
19 | KeyId string | ||
20 | Response chan *jose.JSONWebKey | ||
21 | } | ||
22 | |||
23 | type JWKSFetcher interface { | ||
24 | Run() | ||
25 | Fetch() error | ||
26 | GetKey(string) (*jose.JSONWebKey, error) | ||
27 | Done() | ||
28 | } | ||
29 | |||
30 | type jwksFetcher struct { | ||
31 | keyMap map[string]jose.JSONWebKey | ||
32 | httpClient CautiousHTTPClient | ||
33 | validator KeyValidator | ||
34 | fetchTimer *time.Timer | ||
35 | url *url.URL | ||
36 | requests chan *KeyRequest | ||
37 | done chan bool | ||
38 | } | ||
39 | |||
40 | func NewJWKSFetcher(h CautiousHTTPClient, url *url.URL, issuer string, root string) JWKSFetcher { | ||
41 | val := NewKeyValidator(HostFromURL(issuer)) | ||
42 | val.LoadRootPEM(root) | ||
43 | |||
44 | return &jwksFetcher{ | ||
45 | httpClient: h, | ||
46 | validator: val, | ||
47 | url: url, | ||
48 | fetchTimer: time.NewTimer(DEFAULT_REFRESH_INTERVAL), | ||
49 | requests: make(chan *KeyRequest, REQUEST_BUFFER_SIZE), | ||
50 | keyMap: make(map[string]jose.JSONWebKey, KEY_MAP_INITIAL_SIZE), | ||
51 | done: make(chan bool), | ||
52 | } | ||
53 | } | ||
54 | |||
55 | func (f *jwksFetcher) Fetch() error { | ||
56 | var jwks jose.JSONWebKeySet | ||
57 | timeout, err := f.httpClient.GetJSONExpires(f.url.String(), &jwks) | ||
58 | if err != nil { | ||
59 | return errors.WithStack(err) | ||
60 | } | ||
61 | |||
62 | for _, k := range jwks.Keys { | ||
63 | err = f.validator.Validate(k) | ||
64 | if err == nil { | ||
65 | f.keyMap[k.KeyID] = k | ||
66 | } else { | ||
67 | log.Printf("Rejecting key %q because %q", k.KeyID, err) | ||
68 | } | ||
69 | } | ||
70 | |||
71 | if timeout < MIN_REFRESH_INTERVAL { | ||
72 | timeout = MIN_REFRESH_INTERVAL | ||
73 | } | ||
74 | |||
75 | success := f.fetchTimer.Reset(timeout) | ||
76 | if !success { | ||
77 | f.fetchTimer = time.NewTimer(timeout) | ||
78 | } | ||
79 | |||
80 | return nil | ||
81 | } | ||
82 | |||
83 | func (f *jwksFetcher) Run() { | ||
84 | for { | ||
85 | select { | ||
86 | // Incoming request for a key, return key or nil in no key | ||
87 | case r := <-f.requests: | ||
88 | if v, ok := f.keyMap[r.KeyId]; ok { | ||
89 | r.Response <- &v | ||
90 | } else { | ||
91 | r.Response <- nil | ||
92 | } | ||
93 | case <-f.fetchTimer.C: | ||
94 | f.Fetch() | ||
95 | case <-f.done: | ||
96 | return | ||
97 | } | ||
98 | } | ||
99 | } | ||
100 | |||
101 | func (f *jwksFetcher) Done() { | ||
102 | f.done <- true | ||
103 | } | ||
104 | |||
105 | func (f *jwksFetcher) GetKey(kid string) (*jose.JSONWebKey, error) { | ||
106 | r := &KeyRequest{ | ||
107 | KeyId: kid, | ||
108 | Response: make(chan *jose.JSONWebKey), | ||
109 | } | ||
110 | |||
111 | f.requests <- r | ||
112 | |||
113 | if res := <-r.Response; res == nil { | ||
114 | return nil, errors.Errorf("Key not found for ID") | ||
115 | } else { | ||
116 | return res, nil | ||
117 | } | ||
118 | } | ||
diff --git a/jws_validator.go b/jws_validator.go index e77c026..0b2467f 100644 --- a/jws_validator.go +++ b/jws_validator.go | |||
@@ -1,103 +1,133 @@ | |||
1 | package main | 1 | package main |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "crypto/sha256" | 4 | "github.com/pkg/errors" |
5 | "encoding/hex" | ||
6 | "fmt" | ||
7 | "gopkg.in/square/go-jose.v2" | 5 | "gopkg.in/square/go-jose.v2" |
8 | "gopkg.in/square/go-jose.v2/jwt" | 6 | "gopkg.in/square/go-jose.v2/jwt" |
7 | "net/url" | ||
9 | "time" | 8 | "time" |
10 | ) | 9 | ) |
11 | 10 | ||
11 | // TODO | ||
12 | // validate amr claim contains requested acr values (selective_mfa will be just mfa) | ||
13 | // validate acr claim is the same as requested acr_values | ||
14 | // | ||
15 | // acr_values can be mfa or selective_mfa (mfa only for external users) | ||
16 | // mfa amr values: | ||
17 | // pas - password | ||
18 | // otp - OTP code | ||
19 | // u2f - U2F code | ||
20 | // mfa - multi-factor | ||
21 | // hrd - hardware OTP device used | ||
22 | // sft - software OTP device used | ||
23 | |||
12 | type Claims struct { | 24 | type Claims struct { |
13 | Nonce string `json:"nonce,omitempty"` | 25 | Nonce string `json:"nonce,omitempty"` |
14 | jwt.Claims | 26 | jwt.Claims |
15 | } | 27 | } |
16 | 28 | ||
29 | type JWSValidationContext struct { | ||
30 | KeyFetcher JWKSFetcher | ||
31 | Issuer string | ||
32 | ClientId *url.URL | ||
33 | ClockSkew time.Duration | ||
34 | MaxLiftetime time.Duration | ||
35 | } | ||
36 | |||
17 | type JWSValidator interface { | 37 | type JWSValidator interface { |
18 | Validate(string, string) (*Claims, error) | 38 | Validate(string, string) (*Claims, error) |
19 | } | 39 | } |
20 | 40 | ||
21 | type jwsValidator struct { | 41 | type jwsValidator struct { |
22 | algorithms *stringSet | 42 | algorithms *stringSet |
23 | jwks map[string]jose.JSONWebKey | 43 | jwks JWKSFetcher |
24 | issuer string | 44 | issuer string |
25 | clientID string | 45 | clientId *url.URL |
26 | clockSkew time.Duration | 46 | clockSkew time.Duration |
27 | maxLifetime time.Duration | 47 | maxLifetime time.Duration |
28 | } | 48 | } |
29 | 49 | ||
30 | // TODO | 50 | func NewJWSValidator(c *JWSValidationContext) JWSValidator { |
31 | // validate amr claim contains requested acr values (selective_mfa will be just mfa) | ||
32 | // validate acr claim is the same as requested acr_values | ||
33 | func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator { | ||
34 | return &jwsValidator{ | 51 | return &jwsValidator{ |
35 | algorithms: NewStringSet("PS256", "PS385", "PS512"), | 52 | algorithms: NewStringSet("PS256", "PS385", "PS512"), |
36 | jwks: jwks, | 53 | jwks: c.KeyFetcher, |
37 | issuer: issuer, | 54 | issuer: c.Issuer, |
38 | clientID: client_id, | 55 | clientId: c.ClientId, |
39 | clockSkew: skew, | 56 | clockSkew: c.ClockSkew, |
40 | maxLifetime: max_life, | 57 | maxLifetime: c.MaxLiftetime, |
41 | } | 58 | } |
42 | } | 59 | } |
43 | 60 | ||
44 | func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { | 61 | func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { |
45 | parsed_jwt, err := jwt.ParseSigned(j) | 62 | parsed, err := jwt.ParseSigned(j) |
63 | if err != nil { | ||
64 | return nil, errors.WithStack(err) | ||
65 | } | ||
66 | |||
67 | if err := v.validateHeaders(parsed.Headers); err != nil { | ||
68 | return nil, errors.WithStack(err) | ||
69 | } | ||
70 | |||
71 | kid := parsed.Headers[0].KeyID | ||
72 | key, err := v.jwks.GetKey(kid) | ||
73 | if err != nil { | ||
74 | return nil, errors.WithStack(err) | ||
75 | } | ||
76 | |||
77 | claims, err := v.validateClaims(parsed, key) | ||
46 | if err != nil { | 78 | if err != nil { |
47 | return nil, err | 79 | return nil, errors.WithStack(err) |
48 | } | 80 | } |
49 | 81 | ||
50 | if len(parsed_jwt.Headers) != 1 { | 82 | if err := v.validateNonce(nonce, claims.Nonce); err != nil { |
51 | return nil, fmt.Errorf("Invalid signature count") | 83 | return nil, errors.WithStack(err) |
52 | } | 84 | } |
53 | 85 | ||
54 | head := parsed_jwt.Headers[0] | 86 | return claims, nil |
87 | } | ||
55 | 88 | ||
56 | if !v.algorithms.Contains(head.Algorithm) { | 89 | func (v *jwsValidator) validateHeaders(h []jose.Header) error { |
57 | return nil, fmt.Errorf("Invalid signature algorithm") | 90 | if len(h) != 1 { |
91 | return errors.Errorf("Invalid signature count") | ||
58 | } | 92 | } |
59 | 93 | ||
60 | if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { | 94 | if !v.algorithms.Contains(h[0].Algorithm) { |
61 | return nil, fmt.Errorf("Invalid token type") | 95 | return errors.Errorf("Invalid signature algorithm") |
62 | } | 96 | } |
63 | 97 | ||
64 | key, ok := v.jwks[head.KeyID] | 98 | if typ, ok := h[0].ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { |
65 | if !ok { | 99 | return errors.Errorf("Invalid token type") |
66 | return nil, fmt.Errorf("No key found for key id") | ||
67 | } | 100 | } |
68 | 101 | ||
102 | return nil | ||
103 | } | ||
104 | |||
105 | func (v *jwsValidator) validateClaims(j *jwt.JSONWebToken, k *jose.JSONWebKey) (*Claims, error) { | ||
69 | claims := &Claims{} | 106 | claims := &Claims{} |
70 | if err = parsed_jwt.Claims(key, claims); err != nil { | 107 | if err := j.Claims(k, claims); err != nil { |
71 | return nil, err | 108 | return nil, errors.WithStack(err) |
72 | } | 109 | } |
73 | 110 | ||
74 | exp := jwt.Expected{ | 111 | exp := jwt.Expected{ |
75 | Issuer: v.issuer, | 112 | Issuer: v.issuer, |
76 | Audience: jwt.Audience{v.clientID}, | 113 | Audience: jwt.Audience{v.clientId.String()}, |
77 | Time: time.Now(), | 114 | Time: time.Now(), |
78 | } | 115 | } |
79 | 116 | ||
80 | if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { | 117 | if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { |
81 | return nil, err | 118 | return nil, errors.WithStack(err) |
82 | } | 119 | } |
83 | 120 | ||
84 | if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { | 121 | if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { |
85 | return nil, fmt.Errorf("Token exceeded max lifetime") | 122 | return nil, errors.Errorf("Token exceeded max lifetime") |
86 | } | ||
87 | |||
88 | if err = v.validateNonce(nonce, claims.Nonce); err != nil { | ||
89 | return nil, err | ||
90 | } | 123 | } |
91 | 124 | ||
92 | return claims, nil | 125 | return claims, nil |
93 | } | 126 | } |
94 | 127 | ||
95 | func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { | 128 | func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { |
96 | s256 := sha256.New() | 129 | if token_nonce != Sha256Hex(nonce) { |
97 | s256.Write([]byte(nonce)) | 130 | return errors.Errorf("Invalid nonce: %s = %q vs %q", nonce, token_nonce, Sha256Hex(nonce)) |
98 | hashed_nonce := hex.EncodeToString(s256.Sum(nil)) | ||
99 | if token_nonce != hashed_nonce { | ||
100 | return fmt.Errorf("Invalid nonce") | ||
101 | } | 131 | } |
102 | 132 | ||
103 | return nil | 133 | return nil |
diff --git a/key_validator.go b/key_validator.go index fe6eb7b..062d78c 100644 --- a/key_validator.go +++ b/key_validator.go | |||
@@ -4,11 +4,13 @@ import ( | |||
4 | "crypto/rsa" | 4 | "crypto/rsa" |
5 | "crypto/x509" | 5 | "crypto/x509" |
6 | "encoding/pem" | 6 | "encoding/pem" |
7 | "fmt" | 7 | "github.com/pkg/errors" |
8 | "gopkg.in/square/go-jose.v2" | 8 | "gopkg.in/square/go-jose.v2" |
9 | "io/ioutil" | 9 | "io/ioutil" |
10 | ) | 10 | ) |
11 | 11 | ||
12 | // TODO: CRL validation | ||
13 | |||
12 | type KeyValidator interface { | 14 | type KeyValidator interface { |
13 | Validate(jose.JSONWebKey) error | 15 | Validate(jose.JSONWebKey) error |
14 | LoadRootPEM(string) error | 16 | LoadRootPEM(string) error |
@@ -31,17 +33,17 @@ func NewKeyValidator(subject string) KeyValidator { | |||
31 | func (v *keyValidator) LoadRootPEM(filename string) error { | 33 | func (v *keyValidator) LoadRootPEM(filename string) error { |
32 | pem_data, err := ioutil.ReadFile(filename) | 34 | pem_data, err := ioutil.ReadFile(filename) |
33 | if err != nil { | 35 | if err != nil { |
34 | return err | 36 | return errors.WithStack(err) |
35 | } | 37 | } |
36 | 38 | ||
37 | pem_block, _ := pem.Decode(pem_data) | 39 | pem_block, _ := pem.Decode(pem_data) |
38 | if pem_block == nil { | 40 | if pem_block == nil { |
39 | return fmt.Errorf("PEM decode failed") | 41 | return errors.Errorf("PEM decode failed") |
40 | } | 42 | } |
41 | 43 | ||
42 | cert, err := x509.ParseCertificate(pem_block.Bytes) | 44 | cert, err := x509.ParseCertificate(pem_block.Bytes) |
43 | if err != nil { | 45 | if err != nil { |
44 | return err | 46 | return errors.WithStack(err) |
45 | } | 47 | } |
46 | 48 | ||
47 | v.roots.AddCert(cert) | 49 | v.roots.AddCert(cert) |
@@ -52,40 +54,40 @@ func (v *keyValidator) LoadRootPEM(filename string) error { | |||
52 | func (v *keyValidator) Validate(key jose.JSONWebKey) error { | 54 | func (v *keyValidator) Validate(key jose.JSONWebKey) error { |
53 | pk, ok := key.Key.(*rsa.PublicKey) | 55 | pk, ok := key.Key.(*rsa.PublicKey) |
54 | if !ok { | 56 | if !ok { |
55 | return fmt.Errorf("Key type is not RSA") | 57 | return errors.Errorf("Key type is not RSA") |
56 | } | 58 | } |
57 | 59 | ||
58 | if !v.algorithms.Contains(key.Algorithm) { | 60 | if !v.algorithms.Contains(key.Algorithm) { |
59 | return fmt.Errorf("Key algorithm is not supported") | 61 | return errors.Errorf("Key algorithm is not supported") |
60 | } | 62 | } |
61 | 63 | ||
62 | cert := key.Certificates[0] | 64 | cert := key.Certificates[0] |
63 | cpk, ok := cert.PublicKey.(*rsa.PublicKey) | 65 | cpk, ok := cert.PublicKey.(*rsa.PublicKey) |
64 | if !ok { | 66 | if !ok { |
65 | return fmt.Errorf("Public key is not RSA") | 67 | return errors.Errorf("Public key is not RSA") |
66 | } | 68 | } |
67 | 69 | ||
68 | if cpk.N.BitLen() < 2048 { | 70 | if cpk.N.BitLen() < 2048 { |
69 | return fmt.Errorf("Key length less than 2048 bits") | 71 | return errors.Errorf("Key length less than 2048 bits") |
70 | } | 72 | } |
71 | 73 | ||
72 | if cert.KeyUsage&x509.KeyUsageDigitalSignature != 1 { | 74 | if cert.KeyUsage&x509.KeyUsageDigitalSignature != 1 { |
73 | return fmt.Errorf("Certificate not valid for digital signatures") | 75 | return errors.Errorf("Certificate not valid for digital signatures") |
74 | } | 76 | } |
75 | 77 | ||
76 | err := v.validateCertificateChain(key.Certificates) | 78 | err := v.validateCertificateChain(key.Certificates) |
77 | if err != nil { | 79 | if err != nil { |
78 | return err | 80 | return errors.WithStack(err) |
79 | } | 81 | } |
80 | 82 | ||
81 | err = v.validateCertificateCRL(cert) | 83 | err = v.validateCertificateCRL(cert) |
82 | if err != nil { | 84 | if err != nil { |
83 | return err | 85 | return errors.WithStack(err) |
84 | } | 86 | } |
85 | 87 | ||
86 | err = v.validatePublicKeyInCertificate(pk, cpk) | 88 | err = v.validatePublicKeyInCertificate(pk, cpk) |
87 | if err != nil { | 89 | if err != nil { |
88 | return err | 90 | return errors.WithStack(err) |
89 | } | 91 | } |
90 | 92 | ||
91 | return nil | 93 | return nil |
@@ -116,15 +118,15 @@ func (v *keyValidator) validateCertificateChain(chain []*x509.Certificate) error | |||
116 | 118 | ||
117 | chains, err := chain[0].Verify(vo) | 119 | chains, err := chain[0].Verify(vo) |
118 | if err != nil { | 120 | if err != nil { |
119 | return err | 121 | return errors.WithStack(err) |
120 | } | 122 | } |
121 | 123 | ||
122 | if len(chains) <= 0 { | 124 | if len(chains) <= 0 { |
123 | return fmt.Errorf("No valid certificate chains found") | 125 | return errors.Errorf("No valid certificate chains found") |
124 | } | 126 | } |
125 | 127 | ||
126 | if chain[0].Subject.CommonName != v.pkiSubject { | 128 | if chain[0].Subject.CommonName != v.pkiSubject { |
127 | return fmt.Errorf("Invalid certificate subject name") | 129 | return errors.Errorf("Invalid certificate subject name") |
128 | } | 130 | } |
129 | 131 | ||
130 | return nil | 132 | return nil |
@@ -133,11 +135,11 @@ func (v *keyValidator) validateCertificateChain(chain []*x509.Certificate) error | |||
133 | // validate first item of x5c matches n and e | 135 | // validate first item of x5c matches n and e |
134 | func (v *keyValidator) validatePublicKeyInCertificate(pk *rsa.PublicKey, cpk *rsa.PublicKey) error { | 136 | func (v *keyValidator) validatePublicKeyInCertificate(pk *rsa.PublicKey, cpk *rsa.PublicKey) error { |
135 | if cpk.E != pk.E { | 137 | if cpk.E != pk.E { |
136 | return fmt.Errorf("E in key and E in cert do not match") | 138 | return errors.Errorf("E in key and E in cert do not match") |
137 | } | 139 | } |
138 | 140 | ||
139 | if pk.N.Cmp(cpk.N) != 0 { | 141 | if pk.N.Cmp(cpk.N) != 0 { |
140 | return fmt.Errorf("N in key and N in cert do not match") | 142 | return errors.Errorf("N in key and N in cert do not match") |
141 | } | 143 | } |
142 | 144 | ||
143 | return nil | 145 | return nil |
@@ -4,53 +4,52 @@ import ( | |||
4 | "context" | 4 | "context" |
5 | "crypto/rand" | 5 | "crypto/rand" |
6 | "encoding/hex" | 6 | "encoding/hex" |
7 | "fmt" | 7 | "flag" |
8 | "gopkg.in/square/go-jose.v2" | 8 | "github.com/golang/glog" |
9 | "log" | 9 | "github.com/gorilla/handlers" |
10 | "github.com/pkg/errors" | ||
10 | "net/http" | 11 | "net/http" |
11 | "net/http/httputil" | 12 | "net/http/httputil" |
12 | "net/url" | 13 | "net/url" |
14 | "os" | ||
15 | "strconv" | ||
13 | "strings" | 16 | "strings" |
14 | "time" | 17 | "time" |
15 | ) | 18 | ) |
16 | 19 | ||
17 | const ( | 20 | const ( |
18 | NONCE_SIZE int = 16 | 21 | NONCE_SIZE = 16 |
19 | TOKEN_COOKIE_NAME string = "sso_token" | 22 | TOKEN_COOKIE_NAME = "sso_token" |
20 | RFP_COOKIE_NAME string = "sso_rfp" | 23 | RFP_COOKIE_NAME = "sso_rfp" |
24 | DEFAULT_CLOCK_SKEW = 5 * time.Minute | ||
25 | DEFAULT_MAX_LIFETIME = 24 * time.Hour | ||
26 | DEFAULT_COOKIE_EXP = 48 * time.Hour | ||
21 | ) | 27 | ) |
22 | 28 | ||
23 | // TODO: Enable https checks in HTTP client | 29 | // TODO: MFA support |
24 | |||
25 | // acr_values can be mfa or selective_mfa (mfa only for external users) | ||
26 | // mfa amr values: | ||
27 | // pas - password | ||
28 | // otp - OTP code | ||
29 | // u2f - U2F code | ||
30 | // mfa - multi-factor | ||
31 | // hrd - hardware OTP device used | ||
32 | // sft - software OTP device used | ||
33 | 30 | ||
34 | type ProxyConfig struct { | 31 | type ProxyConfig struct { |
35 | IDProviderURL string | 32 | IdProviderURL *url.URL |
36 | ClientID string | 33 | IdProviderAuthEndpoint *url.URL |
37 | UpstreamURL string | 34 | ClientId *url.URL |
38 | ListenOn string | 35 | UpstreamURL string |
39 | TrustedCACert string | 36 | ListenOn string |
40 | PKISubject string // TODO: Should be same as IDP w/out scheme and port | 37 | TrustedCACert string |
41 | ClockSkew time.Duration | 38 | ClockSkew time.Duration |
42 | MaxLiftetime time.Duration | 39 | MaxLiftetime time.Duration |
43 | IsOptional bool | 40 | IsOptional bool |
44 | RequestMFA bool | 41 | IsBootstrap bool |
45 | AllowedMFAMethods []string // An OR set | 42 | RequestMFA bool |
46 | RequiredMFAMethods []string // An AND set | 43 | AllowedMFAMethods []string // An OR set |
47 | reverseProxy *httputil.ReverseProxy | 44 | RequiredMFAMethods []string // An AND set |
45 | reverseProxy *httputil.ReverseProxy | ||
46 | jwsValidator JWSValidator | ||
48 | } | 47 | } |
49 | 48 | ||
50 | type IdPConfig struct { | 49 | type IdPConfig struct { |
51 | AuthorizationEndpoint string `json:"authorization_endpoint"` | 50 | AuthorizationEndpoint *JSONURL `json:"authorization_endpoint"` |
51 | JwksUri *JSONURL `json:"jwks_uri"` | ||
52 | Issuer string `json:"issuer"` | 52 | Issuer string `json:"issuer"` |
53 | JwksUri string `json:"jwks_uri"` | ||
54 | GrantTypes []string `json:"grant_types_supported"` | 53 | GrantTypes []string `json:"grant_types_supported"` |
55 | IdTokenSigningAlgs []string `json:"id_token_signing_alg_values_supported"` | 54 | IdTokenSigningAlgs []string `json:"id_token_signing_alg_values_supported"` |
56 | ResponseModes []string `json:"response_modes_supported"` | 55 | ResponseModes []string `json:"response_modes_supported"` |
@@ -59,258 +58,311 @@ type IdPConfig struct { | |||
59 | SubjectTypes []string `json:"subject_types_supported"` | 58 | SubjectTypes []string `json:"subject_types_supported"` |
60 | } | 59 | } |
61 | 60 | ||
62 | // TODO: Optimization to fetch only if expired (per http headers) | 61 | func FetchIdPConfig(h CautiousHTTPClient, u *url.URL) (*IdPConfig, error) { |
63 | func FetchIdPConfig(h CautiousHTTPClient, idp_url string) (*IdPConfig, error) { | 62 | u = URLMustParse(u.String()) |
64 | u, err := url.Parse(idp_url) | ||
65 | if err != nil { | ||
66 | return nil, err | ||
67 | } | ||
68 | u.Path = "/.well-known/openid-configuration" | 63 | u.Path = "/.well-known/openid-configuration" |
69 | 64 | ||
70 | var idpc IdPConfig | 65 | var idpc IdPConfig |
71 | err = h.GetJSON(u.String(), &idpc) | 66 | err := h.GetJSON(u.String(), &idpc) |
72 | if err != nil { | 67 | if err != nil { |
73 | return nil, err | 68 | return nil, errors.WithStack(err) |
74 | } | 69 | } |
75 | 70 | ||
76 | return &idpc, nil | 71 | return &idpc, nil |
77 | } | 72 | } |
78 | 73 | ||
79 | // TODO: Optimization to fetch only if expired (per http headers) | ||
80 | func FetchJWKS(h CautiousHTTPClient, jwks_url string, val KeyValidator) (map[string]jose.JSONWebKey, error) { | ||
81 | var jwks jose.JSONWebKeySet | ||
82 | err := h.GetJSON(jwks_url, &jwks) | ||
83 | if err != nil { | ||
84 | return nil, err | ||
85 | } | ||
86 | |||
87 | keys := make(map[string]jose.JSONWebKey, len(jwks.Keys)) | ||
88 | |||
89 | for _, k := range jwks.Keys { | ||
90 | err = val.Validate(k) | ||
91 | if err == nil { | ||
92 | keys[k.KeyID] = k | ||
93 | } | ||
94 | } | ||
95 | |||
96 | return keys, nil | ||
97 | } | ||
98 | |||
99 | func GenerateNonce() (string, error) { | 74 | func GenerateNonce() (string, error) { |
100 | nonce := make([]byte, NONCE_SIZE) | 75 | nonce := make([]byte, NONCE_SIZE) |
101 | n, err := rand.Read(nonce) | 76 | n, err := rand.Read(nonce) |
102 | if n != NONCE_SIZE || err != nil { | 77 | if n != NONCE_SIZE || err != nil { |
103 | return "", err | 78 | return "", errors.WithStack(err) |
104 | } | 79 | } |
105 | return hex.EncodeToString(nonce), nil | 80 | return hex.EncodeToString(nonce), nil |
106 | } | 81 | } |
107 | 82 | ||
108 | // TODO | 83 | func SetSecureCookie(w http.ResponseWriter, name string, value string, exp time.Duration) { |
109 | // Cookie rules | 84 | http.SetCookie(w, &http.Cookie{ |
110 | // Secure | 85 | Name: name, |
111 | // HttpOnly | 86 | Value: value, |
112 | // Path to / | 87 | Expires: time.Now().Add(exp), |
113 | // Expires to iat in JWT | 88 | HttpOnly: true, |
114 | func SetCookie() { | 89 | Secure: true, |
90 | Path: "/", | ||
91 | }) | ||
115 | } | 92 | } |
116 | 93 | ||
117 | // TODO | 94 | func ExpireCookie(w http.ResponseWriter, name string) { |
118 | func MakeClientID(r *http.Request) string { | 95 | http.SetCookie(w, &http.Cookie{ |
119 | if strings.Contains(r.Host, ":") { | 96 | Name: name, |
120 | return r.Host | 97 | Value: "", |
121 | } | 98 | Expires: time.Now().Add(-1 * time.Hour), |
122 | return "" | 99 | HttpOnly: true, |
100 | Secure: true, | ||
101 | Path: "/", | ||
102 | MaxAge: 0, | ||
103 | }) | ||
123 | } | 104 | } |
124 | 105 | ||
125 | // TODO | 106 | func RedirectToIdP(w http.ResponseWriter, r *http.Request, path string) { |
126 | func RedirectToIDP(w http.ResponseWriter, r *http.Request) { | 107 | ctx := r.Context().Value("ProxyConfig").(*ProxyConfig) |
127 | nonce, _ := GenerateNonce() | ||
128 | _ = nonce | ||
129 | nonceh := "" // SHA256 nonce | ||
130 | 108 | ||
131 | // Set nonce cookie | 109 | nonce, err := GenerateNonce() |
110 | if err != nil { | ||
111 | http.Error(w, "Internal Server Error", http.StatusInternalServerError) | ||
112 | return | ||
113 | } | ||
114 | |||
115 | SetSecureCookie(w, RFP_COOKIE_NAME, nonce, DEFAULT_COOKIE_EXP) | ||
116 | |||
117 | rt := "" | ||
118 | rp := r.URL.Query().Get("redirect_uri") | ||
119 | if rp != "" { | ||
120 | rt = rp | ||
121 | } else { | ||
122 | ru := &url.URL{ | ||
123 | Scheme: "https", | ||
124 | Host: r.Host, | ||
125 | Path: path, | ||
126 | } | ||
127 | rt = ru.String() | ||
128 | } | ||
132 | 129 | ||
133 | req := url.Values{} | 130 | req := url.Values{} |
134 | req.Add("client_id", "") // fqdn + : + port | 131 | req.Add("client_id", ctx.ClientId.String()) |
135 | req.Add("nonce", nonceh) | 132 | req.Add("nonce", Sha256Hex(nonce)) |
136 | req.Add("redirect_uri", "") // Requested URL | 133 | req.Add("redirect_uri", rt) |
137 | req.Add("scope", "openid") | 134 | req.Add("scope", "openid") |
138 | req.Add("response_type", "id_token") | 135 | req.Add("response_type", "id_token") |
136 | |||
137 | u := URLMustParse(ctx.IdProviderAuthEndpoint.String()) | ||
138 | u.RawQuery = req.Encode() | ||
139 | |||
140 | http.Redirect(w, r, u.String(), http.StatusFound) | ||
139 | } | 141 | } |
140 | 142 | ||
141 | // TODO: Remove id_token from URL, set cookie and redirect user to requested URL | ||
142 | func SetTokenCookieAndRedirect(w http.ResponseWriter, r *http.Request, token string) { | 143 | func SetTokenCookieAndRedirect(w http.ResponseWriter, r *http.Request, token string) { |
143 | } | 144 | SetSecureCookie(w, TOKEN_COOKIE_NAME, token, DEFAULT_COOKIE_EXP) |
144 | 145 | ||
145 | // TODO | 146 | q := r.URL.Query() |
146 | func ValidateJWT(jwt, rfp string) bool { | 147 | q.Del("id_token") |
147 | return true | 148 | |
148 | } | 149 | u := URLMustParse(r.URL.String()) |
150 | u.RawQuery = q.Encode() | ||
149 | 151 | ||
150 | // TODO | 152 | http.Redirect(w, r, u.String(), http.StatusFound) |
151 | func GetJWTSubject(jwt string) string { | ||
152 | return "" | ||
153 | } | 153 | } |
154 | 154 | ||
155 | func RequestHasForwardedUser(w http.ResponseWriter, r *http.Request) bool { | 155 | func RequestHasForwardedUser(w http.ResponseWriter, r *http.Request) bool { |
156 | if _, ok := r.Header["X-Forwarded-User"]; ok { | 156 | _, ok := r.Header["X-Forwarded-User"] |
157 | log.Printf("ERROR: Request contains X-Forwarded-For header") | 157 | return ok |
158 | http.Error(w, "Bad Request", http.StatusBadRequest) | ||
159 | return true | ||
160 | } else { | ||
161 | return false | ||
162 | } | ||
163 | } | 158 | } |
164 | 159 | ||
165 | func RequestIsOverSecureChannel(w http.ResponseWriter, r *http.Request) bool { | 160 | func RequestIsOverSecureChannel(w http.ResponseWriter, r *http.Request) bool { |
166 | https, ok := r.Header["X-Forwarded-Proto"] | 161 | https, ok := r.Header["X-Forwarded-Proto"] |
167 | if !ok || len(https) != 1 { | 162 | if !ok || len(https) != 1 { |
168 | log.Printf("ERROR: Request does not contain X-Forwarded-Proto header") | 163 | glog.V(1).Infoln("Request does not contain X-Forwarded-Proto header") |
169 | http.Error(w, "Bad Request", http.StatusBadRequest) | ||
170 | return false | 164 | return false |
171 | } | 165 | } |
172 | 166 | ||
173 | if !CompareUpper(https[0], "HTTPS") { | 167 | if !CompareUpper(https[0], "HTTPS") { |
174 | log.Printf("ERROR: Request is not over HTTPS") | 168 | glog.V(1).Infoln("Request is not over HTTPS") |
175 | http.Error(w, "Bad Request", http.StatusBadRequest) | ||
176 | return false | 169 | return false |
177 | } | 170 | } |
178 | 171 | ||
179 | return true | 172 | return true |
180 | } | 173 | } |
181 | 174 | ||
182 | // TODO | ||
183 | // - Validate Hostname header == known hostname | ||
184 | func AuthProxyController(w http.ResponseWriter, r *http.Request) { | 175 | func AuthProxyController(w http.ResponseWriter, r *http.Request) { |
185 | proxy := r.Context().Value("ProxyConfig").(*ProxyConfig).reverseProxy | 176 | ctx := r.Context().Value("ProxyConfig").(*ProxyConfig) |
186 | 177 | ||
187 | // Order matters in these checks! | 178 | // Order matters in these checks! |
188 | if RequestHasForwardedUser(w, r) { | 179 | if RequestHasForwardedUser(w, r) { |
180 | glog.Errorln("Request already has X-Forwarded-User") | ||
181 | http.Error(w, "Bad Request", http.StatusBadRequest) | ||
189 | return | 182 | return |
190 | } | 183 | } |
191 | 184 | ||
192 | if !RequestIsOverSecureChannel(w, r) { | 185 | if !RequestIsOverSecureChannel(w, r) { |
186 | http.Error(w, "Bad Request", http.StatusBadRequest) | ||
193 | return | 187 | return |
194 | } | 188 | } |
195 | 189 | ||
196 | if CompareUpper(r.Method, "OPTIONS") { | 190 | if CompareUpper(r.Method, "OPTIONS") { |
197 | proxy.ServeHTTP(w, r) | 191 | ctx.reverseProxy.ServeHTTP(w, r) |
198 | return | 192 | return |
199 | } | 193 | } |
200 | 194 | ||
201 | rfpc, err := r.Cookie(RFP_COOKIE_NAME) | 195 | rfpc, err := r.Cookie(RFP_COOKIE_NAME) |
202 | if err != nil { | 196 | if err != nil { |
203 | log.Printf("ERROR: No rfp cookie") | 197 | glog.V(1).Infoln("No rfp cookie") |
204 | RedirectToIDP(w, r) | 198 | if ctx.IsOptional { |
205 | return | 199 | ctx.reverseProxy.ServeHTTP(w, r) |
200 | return | ||
201 | } else { | ||
202 | RedirectToIdP(w, r, r.URL.Path) | ||
203 | return | ||
204 | } | ||
206 | } | 205 | } |
207 | 206 | ||
208 | token := r.URL.Query().Get("id_token") | 207 | if token := r.URL.Query().Get("id_token"); token != "" { |
209 | if token != "" && ValidateJWT(rfpc.Value, token) { | 208 | if _, err := ctx.jwsValidator.Validate(token, rfpc.Value); err == nil { |
210 | SetTokenCookieAndRedirect(w, r, token) | 209 | SetTokenCookieAndRedirect(w, r, token) |
211 | return | 210 | return |
211 | } else { | ||
212 | glog.V(1).Infof("Querystring id_token invalid: %s", err) | ||
213 | } | ||
212 | } | 214 | } |
213 | 215 | ||
214 | tokenc, err := r.Cookie(TOKEN_COOKIE_NAME) | 216 | tokenc, err := r.Cookie(TOKEN_COOKIE_NAME) |
215 | if err != nil { | 217 | if err != nil { |
216 | log.Printf("ERROR: No token cookie") | 218 | glog.V(1).Infoln("No token cookie") |
217 | RedirectToIDP(w, r) | 219 | if ctx.IsOptional { |
218 | return | 220 | ctx.reverseProxy.ServeHTTP(w, r) |
221 | return | ||
222 | } else { | ||
223 | RedirectToIdP(w, r, r.URL.Path) | ||
224 | return | ||
225 | } | ||
219 | } | 226 | } |
220 | 227 | ||
221 | if !ValidateJWT(tokenc.Value, rfpc.Value) { | 228 | claims, err := ctx.jwsValidator.Validate(tokenc.Value, rfpc.Value) |
222 | log.Printf("ERROR: Token is invalid") | 229 | if err != nil { |
223 | RedirectToIDP(w, r) | 230 | glog.Errorln("Token is invalid", err) |
224 | return | 231 | if ctx.IsOptional { |
232 | ctx.reverseProxy.ServeHTTP(w, r) | ||
233 | return | ||
234 | } else { | ||
235 | RedirectToIdP(w, r, r.URL.Path) | ||
236 | return | ||
237 | } | ||
225 | } | 238 | } |
226 | 239 | ||
227 | r.Header["X-Forwarded-User"] = []string{GetJWTSubject(tokenc.Value)} | 240 | r.Header["X-Forwarded-User"] = []string{claims.Subject} |
241 | r.Header["X-Forwarded-Token-Expires"] = []string{strconv.FormatInt(int64(claims.Expiry), 10)} | ||
242 | |||
243 | age := time.Since(claims.IssuedAt.Time()).Minutes() | ||
244 | r.Header["X-Forwarded-Token-Age"] = []string{strconv.FormatInt(int64(age), 10)} | ||
228 | 245 | ||
229 | proxy.ServeHTTP(w, r) | 246 | ctx.reverseProxy.ServeHTTP(w, r) |
230 | } | 247 | } |
231 | 248 | ||
232 | // Remove token and rfp cookies and redirect user to root of domain | ||
233 | func LogoutController(w http.ResponseWriter, r *http.Request) { | 249 | func LogoutController(w http.ResponseWriter, r *http.Request) { |
234 | http.SetCookie(w, &http.Cookie{ | 250 | ExpireCookie(w, RFP_COOKIE_NAME) |
235 | Name: TOKEN_COOKIE_NAME, | 251 | ExpireCookie(w, TOKEN_COOKIE_NAME) |
236 | Value: "", | ||
237 | MaxAge: 0, | ||
238 | }) | ||
239 | |||
240 | http.SetCookie(w, &http.Cookie{ | ||
241 | Name: TOKEN_COOKIE_NAME, | ||
242 | Value: "", | ||
243 | MaxAge: 0, | ||
244 | }) | ||
245 | |||
246 | http.Redirect(w, r, "/", http.StatusFound) | 252 | http.Redirect(w, r, "/", http.StatusFound) |
247 | } | 253 | } |
248 | 254 | ||
249 | // TODO | ||
250 | func LoginController(w http.ResponseWriter, r *http.Request) { | ||
251 | } | ||
252 | |||
253 | // TODO | ||
254 | // Optional login allows for applications that can operate in anonymous mode or | 255 | // Optional login allows for applications that can operate in anonymous mode or |
255 | // authenticated mode. When in anonmyous mode the request is proxied through | 256 | // authenticated mode. When in anonmyous mode the request is proxied through |
256 | // without an X-Forwarded-User header. Upstream servers should either expose or | 257 | // without an X-Forwarded-User header. Upstream servers should either expose or |
257 | // map a URL for /.oidc/login to allow users to login. On successful login the | 258 | // map a URL for /.oidc/login to allow users to login. On successful login the |
258 | // user will be redirected back to the main page for the site (/) | 259 | // user will be redirected back to the main page for the site (/) |
259 | func parseConfig() *ProxyConfig { | 260 | func parseConfig() (*ProxyConfig, error) { |
260 | return &ProxyConfig{ | 261 | c := &ProxyConfig{} |
261 | IDProviderURL: "http://mcrute-virt:9993", | 262 | |
262 | ClientID: "test.crute.me:443", | 263 | idpu := flag.String("idp", "", "URL for ID provider") |
263 | UpstreamURL: "http://localhost:9991/", | 264 | mfam := flag.String("allow-mfa-methods", "", "Comma seperated list of allowed mfa methods") |
264 | ListenOn: ":9992", | 265 | rmfa := flag.String("require-mfa-methods", "", "Comma seperated list of required mfa methods") |
265 | TrustedCACert: "/home/mcrute/oidc_project/test_ca/ca_cert.pem", | 266 | cids := flag.String("client-id", "", "Client ID for proxy with IdP") |
266 | IsOptional: false, | 267 | |
267 | PKISubject: "Crute OpenID Signing 1", | 268 | flag.BoolVar(&c.IsOptional, "optional", false, "Allow proxying of unauthenticated calls") |
268 | MaxLiftetime: 24 * time.Hour, | 269 | flag.BoolVar(&c.IsBootstrap, "bootstrap", false, "Allow running a proxy for the IdP itself") |
269 | ClockSkew: 5 * time.Minute, | 270 | flag.BoolVar(&c.RequestMFA, "mfa", false, "Request user MFA authentication from IdP") |
271 | |||
272 | flag.DurationVar(&c.MaxLiftetime, "max-lifetime", DEFAULT_MAX_LIFETIME, "Maximum allowed time from token issuance") | ||
273 | flag.DurationVar(&c.ClockSkew, "clock-skew", DEFAULT_CLOCK_SKEW, "Allowable IdP clock skew relative to proxy") | ||
274 | |||
275 | flag.StringVar(&c.UpstreamURL, "upstream", "", "URL of upstream service for which to proxy") | ||
276 | flag.StringVar(&c.ListenOn, "listen", ":9992", "Optional port and ip on which to listen") | ||
277 | flag.StringVar(&c.TrustedCACert, "ca", "", "Path to trusted CA certificate") | ||
278 | |||
279 | flag.Parse() | ||
280 | |||
281 | c.AllowedMFAMethods = strings.Split(*mfam, ",") | ||
282 | c.RequiredMFAMethods = strings.Split(*rmfa, ",") | ||
283 | |||
284 | if c.IsBootstrap { | ||
285 | c.IsOptional = true | ||
286 | } | ||
287 | |||
288 | if _, err := os.Stat(c.TrustedCACert); os.IsNotExist(err) { | ||
289 | return nil, errors.Errorf("CA certificate does not exist") | ||
290 | } | ||
291 | |||
292 | if cids == nil { | ||
293 | return nil, errors.Errorf("Client ID is required") | ||
294 | } | ||
295 | |||
296 | if client_id, err := url.Parse(*cids); err != nil || client_id.Host == "" { | ||
297 | return nil, errors.Errorf("Invalid client ID") | ||
298 | } else { | ||
299 | c.ClientId = client_id | ||
300 | } | ||
301 | |||
302 | if c.UpstreamURL == "" { | ||
303 | return nil, errors.Errorf("Upstream URL is required") | ||
304 | } | ||
305 | |||
306 | if idpu == nil { | ||
307 | return nil, errors.Errorf("IDP url is required") | ||
270 | } | 308 | } |
309 | |||
310 | if u, err := url.Parse(*idpu); err != nil { | ||
311 | return nil, errors.WithStack(err) | ||
312 | } else { | ||
313 | c.IdProviderURL = u | ||
314 | |||
315 | if h := HostFromURL(u.String()); c.IsBootstrap && (h != "localhost" && h != "127.0.0.1") { | ||
316 | return nil, errors.Errorf("IdP must be set to localhost for bootstrap") | ||
317 | } | ||
318 | } | ||
319 | |||
320 | return c, nil | ||
271 | } | 321 | } |
272 | 322 | ||
273 | func main() { | 323 | func main() { |
274 | cfg := parseConfig() | 324 | cfg, err := parseConfig() |
275 | h := NewCautiousHTTPClient() | 325 | if err != nil { |
276 | 326 | glog.Fatalln("ParseConfig", err) | |
277 | v := NewKeyValidator(cfg.PKISubject) | 327 | return |
278 | v.LoadRootPEM(cfg.TrustedCACert) | 328 | } |
279 | 329 | ||
280 | idpc, err := FetchIdPConfig(h, cfg.IDProviderURL) | 330 | hidp, err := NewCautiousHTTPClient(cfg.IsBootstrap) |
281 | if err != nil { | 331 | if err != nil { |
282 | fmt.Printf("%s\n", err) | 332 | glog.Fatalln("Error building http client", err) |
283 | return | 333 | return |
284 | } | 334 | } |
285 | 335 | ||
286 | jwks, err := FetchJWKS(h, idpc.JwksUri, v) | 336 | idpc, err := FetchIdPConfig(hidp, cfg.IdProviderURL) |
287 | if err != nil { | 337 | if err != nil { |
288 | fmt.Printf("%s\n", err) | 338 | glog.Fatalln("FetchIdPConfig:", err) |
289 | return | 339 | return |
290 | } | 340 | } |
291 | 341 | ||
292 | jv := NewJWSValidator(jwks, idpc.Issuer, cfg.ClientID, cfg.ClockSkew, cfg.MaxLiftetime) | 342 | cfg.IdProviderAuthEndpoint = idpc.AuthorizationEndpoint.AsURL() |
293 | 343 | ||
294 | nonce := "ofspmfjuvoswhhde" | 344 | h, err := NewCautiousHTTPClient(false) |
295 | raw_jwt := "eyJ0eXAiOiJKV1MiLCJhbGciOiJQUzI1NiIsImtpZCI6IjEifQ.eyJub25jZSI6IjM0MjlhMjAyYzU4ZDkyYjQwNjNjOWM4MWM2MjQyNGRlNzBkMmIzZDQ4MmVlNDFhOTdjYmNhZjEwZDk5MWFiOTMiLCJpc3MiOiJpZHAuY3J1dGUubWU6NDQzIiwiaWF0IjoxNTA0NTc2Mzc0LCJuYmYiOjE1MDQ1NzYzNzQsImV4cCI6MTUwNDY2Mjc3NCwic3ViIjoibWNydXRlIiwiYXVkIjoidGVzdC5jcnV0ZS5tZTo0NDMifQ.iizlNfY1Vg7d-XRmgyYuhpNkNrOGaT9OOgO0HdjBozOWMvKzBTtATbIfoWOrNH6DiFY1as8uy3I1Pxnkrb8Ti8_cLDQeLxOv9klAbnebeuPI_wtZ0iwSUnSWaYzN6I6sqcEjHX3fibFvAQhO5dNDzSwONjw4AvcdpZKh579FO1sAvIw-1DmMyPSUun7rbC0Kf1Jtdlr3q7tOp3wdI_erkstxCNPwyuv7X1J7uetsu0BeJS25C2DxeB03BPEIUoo_C1xvcqikfSLLpoFcyToYiS-R9o-WpRjGid_yug65J5ALn2aM3vhe9rRbydKVm_omGL8-Etj06zbqM0Y6OrJUgA" | ||
296 | claims, err := jv.Validate(raw_jwt, nonce) | ||
297 | if err != nil { | 345 | if err != nil { |
298 | fmt.Printf("Error validating: %s\n", err) | 346 | glog.Fatalln("Error building http client", err) |
299 | return | 347 | return |
300 | } | 348 | } |
301 | 349 | ||
302 | fmt.Printf("Valid JWT for: %+v\n", claims.Subject) | 350 | kf := NewJWKSFetcher(h, idpc.JwksUri.AsURL(), idpc.Issuer, cfg.TrustedCACert) |
303 | 351 | ||
304 | return | 352 | cfg.jwsValidator = NewJWSValidator(&JWSValidationContext{ |
353 | KeyFetcher: kf, | ||
354 | Issuer: idpc.Issuer, | ||
355 | ClientId: cfg.ClientId, | ||
356 | ClockSkew: cfg.ClockSkew, | ||
357 | MaxLiftetime: cfg.MaxLiftetime, | ||
358 | }) | ||
305 | 359 | ||
306 | cfg.reverseProxy = httputil.NewSingleHostReverseProxy(URLMustParse(cfg.UpstreamURL)) | 360 | cfg.reverseProxy = httputil.NewSingleHostReverseProxy(URLMustParse(cfg.UpstreamURL)) |
307 | 361 | ||
308 | if cfg.IsOptional { | 362 | http.HandleFunc("/.oidc/login", func(w http.ResponseWriter, r *http.Request) { |
309 | http.HandleFunc("/.oidc/login", func(w http.ResponseWriter, r *http.Request) { | 363 | RedirectToIdP(w, |
310 | LoginController(w, | 364 | r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg)), "/") |
311 | r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg))) | 365 | }) |
312 | }) | ||
313 | } | ||
314 | 366 | ||
315 | http.HandleFunc("/.oidc/logout", func(w http.ResponseWriter, r *http.Request) { | 367 | http.HandleFunc("/.oidc/logout", func(w http.ResponseWriter, r *http.Request) { |
316 | LogoutController(w, | 368 | LogoutController(w, |
@@ -322,5 +374,14 @@ func main() { | |||
322 | r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg))) | 374 | r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg))) |
323 | }) | 375 | }) |
324 | 376 | ||
325 | log.Fatal(http.ListenAndServe(cfg.ListenOn, nil)) | 377 | go http.ListenAndServe(cfg.ListenOn, |
378 | handlers.LoggingHandler(os.Stdout, http.DefaultServeMux)) | ||
379 | |||
380 | // This has to happen last in-case we're boostrapping a proxy for the IdP itself | ||
381 | if err := kf.Fetch(); err != nil { | ||
382 | glog.Fatalln("FetchJWKS:", err) | ||
383 | return | ||
384 | } else { | ||
385 | kf.Run() | ||
386 | } | ||
326 | } | 387 | } |
Binary files differ | |||
@@ -1,6 +1,8 @@ | |||
1 | package main | 1 | package main |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "crypto/sha256" | ||
5 | "encoding/hex" | ||
4 | "net/url" | 6 | "net/url" |
5 | "strings" | 7 | "strings" |
6 | ) | 8 | ) |
@@ -41,3 +43,19 @@ func URLMustParse(u string) *url.URL { | |||
41 | func CompareUpper(lhs, rhs string) bool { | 43 | func CompareUpper(lhs, rhs string) bool { |
42 | return strings.ToUpper(lhs) == strings.ToUpper(rhs) | 44 | return strings.ToUpper(lhs) == strings.ToUpper(rhs) |
43 | } | 45 | } |
46 | |||
47 | func HostFromURL(u string) string { | ||
48 | o, err := url.Parse(u) | ||
49 | if err != nil { | ||
50 | return "" | ||
51 | } | ||
52 | |||
53 | h := strings.Split(o.Host, ":") | ||
54 | return h[0] | ||
55 | } | ||
56 | |||
57 | func Sha256Hex(v string) string { | ||
58 | s256 := sha256.New() | ||
59 | s256.Write([]byte(v)) | ||
60 | return hex.EncodeToString(s256.Sum(nil)) | ||
61 | } | ||