Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: refreshable *tls.Config support for CGR clients #725

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions conjure-go-client/httpclient/client_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package httpclient

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"time"
Expand Down Expand Up @@ -70,10 +69,12 @@ type clientBuilder struct {
}

type httpClientBuilder struct {
ServiceName refreshable.String
Timeout refreshable.Duration
DialerParams refreshingclient.RefreshableDialerParams
TLSConfig *tls.Config // If unset, config in TransportParams will be used.
ServiceName refreshable.String
Timeout refreshable.Duration
DialerParams refreshingclient.RefreshableDialerParams
// TLSConfig supplies the *tls.Config for the underlying transport to use.
// If unset, config in TransportParams will be used.
TLSConfig refreshingclient.RefreshableTLSConf
TransportParams refreshingclient.RefreshableTransportParams
Middlewares []Middleware

Expand All @@ -97,11 +98,11 @@ func (b *httpClientBuilder) Build(ctx context.Context, params ...HTTPClientParam
}
}

var tlsProvider refreshingclient.TLSProvider
var tlsProvider refreshingclient.RefreshableTLSConf
if b.TLSConfig != nil {
tlsProvider = refreshingclient.NewStaticTLSConfigProvider(b.TLSConfig)
tlsProvider = b.TLSConfig
} else {
refreshableProvider, err := refreshingclient.NewRefreshableTLSConfig(ctx, b.TransportParams.TLS())
refreshableProvider, err := refreshingclient.NewRefreshableTLSConfigFromParams(ctx, b.TransportParams.TLS())
if err != nil {
return nil, err
}
Expand Down
28 changes: 26 additions & 2 deletions conjure-go-client/httpclient/client_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,27 @@ func WithTLSConfig(conf *tls.Config) ClientOrHTTPClientParam {
if conf == nil {
b.TLSConfig = nil
} else {
b.TLSConfig = conf.Clone()
b.TLSConfig = refreshingclient.NewStaticTLSConfigProvider(conf.Clone())
}
return nil
})
}

// WithRefreshableTLSConfig allows a user to pass in a refreshable *tls.Config.
// Note that it is undesirable from a performance perspective to pass in a *refreshable.DefaultRefreshable
// here, as that takes a dependency on reflect.DeepEqual. Updates to the underlying *tls.Config cause
// a complete transport refresh, which is a relatively expensive operation.
// Clients are responsible for passing in a refreshable which fulfills their performance requirements.
func WithRefreshableTLSConfig(r refreshable.Refreshable) ClientOrHTTPClientParam {
return clientOrHTTPClientParamFunc(func(b *httpClientBuilder) error {
var err error
if r == nil {
b.TLSConfig = nil
} else {
b.TLSConfig, err = refreshingclient.NewRefreshableTLSConfigFromRefreshable(r)
if err != nil {
return err
}
}
return nil
})
Expand All @@ -366,7 +386,11 @@ func WithTLSConfig(conf *tls.Config) ClientOrHTTPClientParam {
func WithTLSInsecureSkipVerify() ClientOrHTTPClientParam {
return clientOrHTTPClientParamFunc(func(b *httpClientBuilder) error {
if b.TLSConfig != nil {
b.TLSConfig.InsecureSkipVerify = true
b.TLSConfig = refreshingclient.ConfigureTLSConfig(b.TLSConfig, func(conf *tls.Config) *tls.Config {
conf = conf.Clone()
conf.InsecureSkipVerify = true
return conf
})
}
b.TransportParams = refreshingclient.ConfigureTransport(b.TransportParams, func(p refreshingclient.TransportParams) refreshingclient.TransportParams {
p.TLS.InsecureSkipVerify = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,64 @@ package refreshingclient
import (
"context"
"crypto/tls"
"errors"
"sync"
"sync/atomic"

"github.com/palantir/pkg/refreshable"
"github.com/palantir/pkg/tlsconfig"
werror "github.com/palantir/witchcraft-go-error"
"github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log"
)

type RefreshableTLSConf interface {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Spell out RefreshableTLSConfig to match tls.Config

GetTLSConfig(ctx context.Context) *tls.Config
SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func())
}

var _ RefreshableTLSConf = (*MappedRefreshableTLSConfig)(nil)

func ConfigureTLSConfig(r RefreshableTLSConf, mapFn func(conf *tls.Config) *tls.Config) RefreshableTLSConf {
var m MappedRefreshableTLSConfig
r.SubscribeToTLSConfig(func(c *tls.Config) {
m.update(mapFn(c))
})
return &m
}

type MappedRefreshableTLSConfig struct {
conf atomic.Pointer[tls.Config]

mu sync.Mutex // protects subscribers
subscribers []*func(*tls.Config)
}

// GetTLSConfig implements RefreshableTLSConf.
func (m *MappedRefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config {
return m.conf.Load()
}

func (m *MappedRefreshableTLSConfig) update(conf *tls.Config) {
m.conf.Store(conf)

m.mu.Lock()
defer m.mu.Unlock()
for _, sub := range m.subscribers {
(*sub)(conf)
}
}

// SubscribeToTLSConfig implements RefreshableTLSConf.
func (m *MappedRefreshableTLSConfig) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) {
m.mu.Lock()
defer m.mu.Unlock()

consumerFnPtr := &consumer
m.subscribers = append(m.subscribers, consumerFnPtr)
// TODO(smenon): implement unsubcribe
return func() {}
}

// TLSParams contains the parameters needed to build a *tls.Config.
// Its fields must all be compatible with reflect.DeepEqual.
type TLSParams struct {
Expand All @@ -40,45 +91,74 @@ type TLSProvider interface {
// StaticTLSConfigProvider is a TLSProvider that always returns the same *tls.Config.
type StaticTLSConfigProvider tls.Config

func NewStaticTLSConfigProvider(tlsConfig *tls.Config) *StaticTLSConfigProvider {
func NewStaticTLSConfigProvider(tlsConfig *tls.Config) RefreshableTLSConf {
return (*StaticTLSConfigProvider)(tlsConfig)
}

func (p *StaticTLSConfigProvider) GetTLSConfig(context.Context) *tls.Config {
return (*tls.Config)(p)
}

type RefreshableTLSConfig struct {
// SubscribeToTLSConfig implements RefreshableTLSConf.
func (p *StaticTLSConfigProvider) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) {
return nil
}

type WrappedRefreshableTLSConfig struct {
r *refreshable.ValidatingRefreshable // contains *tls.Config
}

// NewRefreshableTLSConfig evaluates the provided TLSParams and returns a RefreshableTLSConfig that will update the
// NewRefreshableTLSConfigFromParams evaluates the provided TLSParams and returns a RefreshableTLSConfig that will update the
// underlying *tls.Config when the TLSParams change.
// IF the initial TLSParams are invalid, NewRefreshableTLSConfig will return an error.
// IF the initial TLSParams are invalid, NewRefreshableTLSConfigFromParams will return an error.
// If the updated TLSParams are invalid, the RefreshableTLSConfig will continue to use the previous value and log the error.
//
// N.B. This subscription only fires when the paths are updated, not when the contents of the files are updated.
// We could consider adding a file refreshable to watch the key and cert files.
func NewRefreshableTLSConfig(ctx context.Context, params RefreshableTLSParams) (TLSProvider, error) {
func NewRefreshableTLSConfigFromParams(ctx context.Context, params RefreshableTLSParams) (RefreshableTLSConf, error) {
r, err := refreshable.NewMapValidatingRefreshable(params, func(i interface{}) (interface{}, error) {
return NewTLSConfig(ctx, i.(TLSParams))
})
if err != nil {
return nil, werror.WrapWithContextParams(ctx, err, "failed to build RefreshableTLSConfig")
}
return RefreshableTLSConfig{r: r}, nil
return WrappedRefreshableTLSConfig{r: r}, nil
}

func NewRefreshableTLSConfigFromRefreshable(r refreshable.Refreshable) (RefreshableTLSConf, error) {
validating, err := refreshable.NewValidatingRefreshable(r, func(i interface{}) error {
_, ok := r.Current().(*tls.Config)
if !ok {
// TODO(smenon): proper error msg.
return errors.New("invalid type for refreshable")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generally we are ok with panicking if the type of a refreshable is incorrect (as it's programmer error and should easily get caught the first time the code runs) so I don't think you need NewValidatingRefreshable here

}
return nil
})
if err != nil {
return nil, werror.Wrap(err, "failed to build RefreshableTLSConfig")
}
return WrappedRefreshableTLSConfig{
r: validating,
}, nil
}

// GetTLSConfig returns the most recent valid *tls.Config.
// If the last refreshable update resulted in an error, that error is logged and
// the previous value is returned.
func (r RefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config {
func (r WrappedRefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config {
if err := r.r.LastValidateErr(); err != nil {
svc1log.FromContext(ctx).Warn("Invalid TLS config. Using previous value.", svc1log.Stacktrace(err))
}
return r.r.Current().(*tls.Config)
}

// SubscribeToTLSConfig implements RefreshableTLSConf.
func (r WrappedRefreshableTLSConfig) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) {
return r.r.Subscribe(func(i interface{}) {
consumer(i.(*tls.Config))
})
}

// NewTLSConfig returns a *tls.Config built from the provided TLSParams.
func NewTLSConfig(ctx context.Context, p TLSParams) (*tls.Config, error) {
var tlsParams []tlsconfig.ClientParam
Expand Down
34 changes: 21 additions & 13 deletions conjure-go-client/httpclient/internal/refreshingclient/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ package refreshingclient

import (
"context"
"crypto/tls"
"net/http"
"net/url"
"sync/atomic"
"time"

"github.com/palantir/pkg/refreshable"
"github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log"
"golang.org/x/net/http2"
)
Expand All @@ -42,12 +43,21 @@ type TransportParams struct {
TLS TLSParams
}

func NewRefreshableTransport(ctx context.Context, p RefreshableTransportParams, tlsProvider TLSProvider, dialer ContextDialer) http.RoundTripper {
return &RefreshableTransport{
Refreshable: p.MapTransportParams(func(p TransportParams) interface{} {
return newTransport(ctx, p, tlsProvider, dialer)
}),
}
func NewRefreshableTransport(ctx context.Context, p RefreshableTransportParams, t RefreshableTLSConf, dialer ContextDialer) http.RoundTripper {
var refreshingTransport RefreshableTransport

// initialize the transport the first time.
refreshingTransport.Update(ctx, p.CurrentTransportParams(), t.GetTLSConfig(ctx), dialer)

// also subscribe to updates on transport params and the tls provider.
p.SubscribeToTransportParams(func(tp TransportParams) {
refreshingTransport.Update(ctx, tp, t.GetTLSConfig(ctx), dialer)
})
t.SubscribeToTLSConfig(func(conf *tls.Config) {
refreshingTransport.Update(ctx, p.CurrentTransportParams(), conf, dialer)
})

return &refreshingTransport
}

// ConfigureTransport accepts a mapping function which will be applied to the params value as it is evaluated.
Expand All @@ -61,14 +71,14 @@ func ConfigureTransport(r RefreshableTransportParams, mapFn func(p TransportPara
// RefreshableTransport implements http.RoundTripper backed by a refreshable *http.Transport.
// The transport and internal dialer are each rebuilt when any of their respective parameters are updated.
type RefreshableTransport struct {
refreshable.Refreshable // contains *http.Transport
t atomic.Pointer[http.Transport]
}

func (r *RefreshableTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return r.Current().(*http.Transport).RoundTrip(req)
return r.t.Load().RoundTrip(req)
}

func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvider, dialer ContextDialer) *http.Transport {
func (r *RefreshableTransport) Update(ctx context.Context, p TransportParams, tlsConfig *tls.Config, dialer ContextDialer) {
svc1log.FromContext(ctx).Debug("Reconstructing HTTP Transport")

var transportProxy func(*http.Request) (*url.URL, error)
Expand All @@ -78,7 +88,6 @@ func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvide
transportProxy = http.ProxyFromEnvironment
}

tlsConfig := tlsProvider.GetTLSConfig(ctx)
transport := &http.Transport{
Proxy: transportProxy,
DialContext: dialer.DialContext,
Expand Down Expand Up @@ -115,6 +124,5 @@ func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvide
http2Transport.PingTimeout = p.HTTP2PingTimeout
}
}

return transport
r.t.Store(transport)
}