From dea2289ce89b23c2665f87114c8773f35edde969 Mon Sep 17 00:00:00 2001 From: Sid Menon Date: Wed, 27 Nov 2024 17:40:26 -0500 Subject: [PATCH 1/2] (1/n): no-op - refreshable tls --- .../httpclient/client_builder.go | 17 +++-- conjure-go-client/httpclient/client_params.go | 8 +- .../internal/refreshingclient/tlsconfig.go | 76 +++++++++++++++++-- .../internal/refreshingclient/transport.go | 34 +++++---- 4 files changed, 105 insertions(+), 30 deletions(-) diff --git a/conjure-go-client/httpclient/client_builder.go b/conjure-go-client/httpclient/client_builder.go index 1bf8461d..73d24492 100644 --- a/conjure-go-client/httpclient/client_builder.go +++ b/conjure-go-client/httpclient/client_builder.go @@ -17,7 +17,6 @@ package httpclient import ( "context" - "crypto/tls" "fmt" "net/http" "time" @@ -70,10 +69,12 @@ type clientBuilder struct { } type httpClientBuilder struct { - ServiceName refreshable.String - Timeout refreshable.Duration - DialerParams refreshingclient.RefreshableDialerParams - TLSConfig *tls.Config // If unset, config in TransportParams will be used. + ServiceName refreshable.String + Timeout refreshable.Duration + DialerParams refreshingclient.RefreshableDialerParams + // TLSConfig supplies the *tls.Config for the underlying transport to use. + // If unset, config in TransportParams will be used. + TLSConfig refreshingclient.RefreshableTLSConf TransportParams refreshingclient.RefreshableTransportParams Middlewares []Middleware @@ -97,11 +98,11 @@ func (b *httpClientBuilder) Build(ctx context.Context, params ...HTTPClientParam } } - var tlsProvider refreshingclient.TLSProvider + var tlsProvider refreshingclient.RefreshableTLSConf if b.TLSConfig != nil { - tlsProvider = refreshingclient.NewStaticTLSConfigProvider(b.TLSConfig) + tlsProvider = b.TLSConfig } else { - refreshableProvider, err := refreshingclient.NewRefreshableTLSConfig(ctx, b.TransportParams.TLS()) + refreshableProvider, err := refreshingclient.NewRefreshableTLSConfigFromParams(ctx, b.TransportParams.TLS()) if err != nil { return nil, err } diff --git a/conjure-go-client/httpclient/client_params.go b/conjure-go-client/httpclient/client_params.go index ef3bbeb9..13d03827 100644 --- a/conjure-go-client/httpclient/client_params.go +++ b/conjure-go-client/httpclient/client_params.go @@ -354,7 +354,7 @@ func WithTLSConfig(conf *tls.Config) ClientOrHTTPClientParam { if conf == nil { b.TLSConfig = nil } else { - b.TLSConfig = conf.Clone() + b.TLSConfig = refreshingclient.NewStaticTLSConfigProvider(conf.Clone()) } return nil }) @@ -366,7 +366,11 @@ func WithTLSConfig(conf *tls.Config) ClientOrHTTPClientParam { func WithTLSInsecureSkipVerify() ClientOrHTTPClientParam { return clientOrHTTPClientParamFunc(func(b *httpClientBuilder) error { if b.TLSConfig != nil { - b.TLSConfig.InsecureSkipVerify = true + b.TLSConfig = refreshingclient.ConfigureTLSConfig(b.TLSConfig, func(conf *tls.Config) *tls.Config { + conf = conf.Clone() + conf.InsecureSkipVerify = true + return conf + }) } b.TransportParams = refreshingclient.ConfigureTransport(b.TransportParams, func(p refreshingclient.TransportParams) refreshingclient.TransportParams { p.TLS.InsecureSkipVerify = true diff --git a/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go b/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go index 401e66cd..00d74720 100644 --- a/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go +++ b/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go @@ -17,6 +17,8 @@ package refreshingclient import ( "context" "crypto/tls" + "sync" + "sync/atomic" "github.com/palantir/pkg/refreshable" "github.com/palantir/pkg/tlsconfig" @@ -24,6 +26,54 @@ import ( "github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log" ) +type RefreshableTLSConf interface { + GetTLSConfig(ctx context.Context) *tls.Config + SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) +} + +var _ RefreshableTLSConf = (*MappedRefreshableTLSConfig)(nil) + +func ConfigureTLSConfig(r RefreshableTLSConf, mapFn func(conf *tls.Config) *tls.Config) RefreshableTLSConf { + var m MappedRefreshableTLSConfig + r.SubscribeToTLSConfig(func(c *tls.Config) { + m.update(mapFn(c)) + }) + return &m +} + +type MappedRefreshableTLSConfig struct { + conf atomic.Pointer[tls.Config] + + mu sync.Mutex // protects subscribers + subscribers []*func(*tls.Config) +} + +// GetTLSConfig implements RefreshableTLSConf. +func (m *MappedRefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config { + return m.conf.Load() +} + +func (m *MappedRefreshableTLSConfig) update(conf *tls.Config) { + m.conf.Store(conf) + + m.mu.Lock() + defer m.mu.Unlock() + for _, sub := range m.subscribers { + (*sub)(conf) + } +} + +// SubscribeToTLSConfig implements RefreshableTLSConf. +func (m *MappedRefreshableTLSConfig) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) { + m.mu.Lock() + defer m.mu.Unlock() + + consumerFnPtr := &consumer + m.subscribers = append(m.subscribers, consumerFnPtr) + // TODO(smenon): implement unsubcribe + return func() {} +} + // TLSParams contains the parameters needed to build a *tls.Config. // Its fields must all be compatible with reflect.DeepEqual. type TLSParams struct { @@ -40,7 +90,7 @@ type TLSProvider interface { // StaticTLSConfigProvider is a TLSProvider that always returns the same *tls.Config. type StaticTLSConfigProvider tls.Config -func NewStaticTLSConfigProvider(tlsConfig *tls.Config) *StaticTLSConfigProvider { +func NewStaticTLSConfigProvider(tlsConfig *tls.Config) RefreshableTLSConf { return (*StaticTLSConfigProvider)(tlsConfig) } @@ -48,37 +98,49 @@ func (p *StaticTLSConfigProvider) GetTLSConfig(context.Context) *tls.Config { return (*tls.Config)(p) } -type RefreshableTLSConfig struct { +// SubscribeToTLSConfig implements RefreshableTLSConf. +func (p *StaticTLSConfigProvider) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) { + return nil +} + +type WrappedRefreshableTLSConfig struct { r *refreshable.ValidatingRefreshable // contains *tls.Config } -// NewRefreshableTLSConfig evaluates the provided TLSParams and returns a RefreshableTLSConfig that will update the +// NewRefreshableTLSConfigFromParams evaluates the provided TLSParams and returns a RefreshableTLSConfig that will update the // underlying *tls.Config when the TLSParams change. -// IF the initial TLSParams are invalid, NewRefreshableTLSConfig will return an error. +// IF the initial TLSParams are invalid, NewRefreshableTLSConfigFromParams will return an error. // If the updated TLSParams are invalid, the RefreshableTLSConfig will continue to use the previous value and log the error. // // N.B. This subscription only fires when the paths are updated, not when the contents of the files are updated. // We could consider adding a file refreshable to watch the key and cert files. -func NewRefreshableTLSConfig(ctx context.Context, params RefreshableTLSParams) (TLSProvider, error) { +func NewRefreshableTLSConfigFromParams(ctx context.Context, params RefreshableTLSParams) (RefreshableTLSConf, error) { r, err := refreshable.NewMapValidatingRefreshable(params, func(i interface{}) (interface{}, error) { return NewTLSConfig(ctx, i.(TLSParams)) }) if err != nil { return nil, werror.WrapWithContextParams(ctx, err, "failed to build RefreshableTLSConfig") } - return RefreshableTLSConfig{r: r}, nil + return WrappedRefreshableTLSConfig{r: r}, nil } // GetTLSConfig returns the most recent valid *tls.Config. // If the last refreshable update resulted in an error, that error is logged and // the previous value is returned. -func (r RefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config { +func (r WrappedRefreshableTLSConfig) GetTLSConfig(ctx context.Context) *tls.Config { if err := r.r.LastValidateErr(); err != nil { svc1log.FromContext(ctx).Warn("Invalid TLS config. Using previous value.", svc1log.Stacktrace(err)) } return r.r.Current().(*tls.Config) } +// SubscribeToTLSConfig implements RefreshableTLSConf. +func (r WrappedRefreshableTLSConfig) SubscribeToTLSConfig(consumer func(*tls.Config)) (unsubscribe func()) { + return r.r.Subscribe(func(i interface{}) { + consumer(i.(*tls.Config)) + }) +} + // NewTLSConfig returns a *tls.Config built from the provided TLSParams. func NewTLSConfig(ctx context.Context, p TLSParams) (*tls.Config, error) { var tlsParams []tlsconfig.ClientParam diff --git a/conjure-go-client/httpclient/internal/refreshingclient/transport.go b/conjure-go-client/httpclient/internal/refreshingclient/transport.go index dc7c9b39..c4a8c972 100644 --- a/conjure-go-client/httpclient/internal/refreshingclient/transport.go +++ b/conjure-go-client/httpclient/internal/refreshingclient/transport.go @@ -16,11 +16,12 @@ package refreshingclient import ( "context" + "crypto/tls" "net/http" "net/url" + "sync/atomic" "time" - "github.com/palantir/pkg/refreshable" "github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log" "golang.org/x/net/http2" ) @@ -42,12 +43,21 @@ type TransportParams struct { TLS TLSParams } -func NewRefreshableTransport(ctx context.Context, p RefreshableTransportParams, tlsProvider TLSProvider, dialer ContextDialer) http.RoundTripper { - return &RefreshableTransport{ - Refreshable: p.MapTransportParams(func(p TransportParams) interface{} { - return newTransport(ctx, p, tlsProvider, dialer) - }), - } +func NewRefreshableTransport(ctx context.Context, p RefreshableTransportParams, t RefreshableTLSConf, dialer ContextDialer) http.RoundTripper { + var refreshingTransport RefreshableTransport + + // initialize the transport the first time. + refreshingTransport.Update(ctx, p.CurrentTransportParams(), t.GetTLSConfig(ctx), dialer) + + // also subscribe to updates on transport params and the tls provider. + p.SubscribeToTransportParams(func(tp TransportParams) { + refreshingTransport.Update(ctx, tp, t.GetTLSConfig(ctx), dialer) + }) + t.SubscribeToTLSConfig(func(conf *tls.Config) { + refreshingTransport.Update(ctx, p.CurrentTransportParams(), conf, dialer) + }) + + return &refreshingTransport } // ConfigureTransport accepts a mapping function which will be applied to the params value as it is evaluated. @@ -61,14 +71,14 @@ func ConfigureTransport(r RefreshableTransportParams, mapFn func(p TransportPara // RefreshableTransport implements http.RoundTripper backed by a refreshable *http.Transport. // The transport and internal dialer are each rebuilt when any of their respective parameters are updated. type RefreshableTransport struct { - refreshable.Refreshable // contains *http.Transport + t atomic.Pointer[http.Transport] } func (r *RefreshableTransport) RoundTrip(req *http.Request) (*http.Response, error) { - return r.Current().(*http.Transport).RoundTrip(req) + return r.t.Load().RoundTrip(req) } -func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvider, dialer ContextDialer) *http.Transport { +func (r *RefreshableTransport) Update(ctx context.Context, p TransportParams, tlsConfig *tls.Config, dialer ContextDialer) { svc1log.FromContext(ctx).Debug("Reconstructing HTTP Transport") var transportProxy func(*http.Request) (*url.URL, error) @@ -78,7 +88,6 @@ func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvide transportProxy = http.ProxyFromEnvironment } - tlsConfig := tlsProvider.GetTLSConfig(ctx) transport := &http.Transport{ Proxy: transportProxy, DialContext: dialer.DialContext, @@ -115,6 +124,5 @@ func newTransport(ctx context.Context, p TransportParams, tlsProvider TLSProvide http2Transport.PingTimeout = p.HTTP2PingTimeout } } - - return transport + r.t.Store(transport) } From 853eafafeaef3151e7099def5c01c450ee117900 Mon Sep 17 00:00:00 2001 From: Sid Menon Date: Wed, 27 Nov 2024 17:40:44 -0500 Subject: [PATCH 2/2] (2/2): pass in refreshable --- conjure-go-client/httpclient/client_params.go | 20 +++++++++++++++++++ .../internal/refreshingclient/tlsconfig.go | 18 +++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/conjure-go-client/httpclient/client_params.go b/conjure-go-client/httpclient/client_params.go index 13d03827..0cea1627 100644 --- a/conjure-go-client/httpclient/client_params.go +++ b/conjure-go-client/httpclient/client_params.go @@ -360,6 +360,26 @@ func WithTLSConfig(conf *tls.Config) ClientOrHTTPClientParam { }) } +// WithRefreshableTLSConfig allows a user to pass in a refreshable *tls.Config. +// Note that it is undesirable from a performance perspective to pass in a *refreshable.DefaultRefreshable +// here, as that takes a dependency on reflect.DeepEqual. Updates to the underlying *tls.Config cause +// a complete transport refresh, which is a relatively expensive operation. +// Clients are responsible for passing in a refreshable which fulfills their performance requirements. +func WithRefreshableTLSConfig(r refreshable.Refreshable) ClientOrHTTPClientParam { + return clientOrHTTPClientParamFunc(func(b *httpClientBuilder) error { + var err error + if r == nil { + b.TLSConfig = nil + } else { + b.TLSConfig, err = refreshingclient.NewRefreshableTLSConfigFromRefreshable(r) + if err != nil { + return err + } + } + return nil + }) +} + // WithTLSInsecureSkipVerify sets the InsecureSkipVerify field for the HTTP client's tls config. // This option should only be used in clients that have way to establish trust with servers. // If WithTLSConfig is used, the config's InsecureSkipVerify is set to true. diff --git a/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go b/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go index 00d74720..57c93776 100644 --- a/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go +++ b/conjure-go-client/httpclient/internal/refreshingclient/tlsconfig.go @@ -17,6 +17,7 @@ package refreshingclient import ( "context" "crypto/tls" + "errors" "sync" "sync/atomic" @@ -124,6 +125,23 @@ func NewRefreshableTLSConfigFromParams(ctx context.Context, params RefreshableTL return WrappedRefreshableTLSConfig{r: r}, nil } +func NewRefreshableTLSConfigFromRefreshable(r refreshable.Refreshable) (RefreshableTLSConf, error) { + validating, err := refreshable.NewValidatingRefreshable(r, func(i interface{}) error { + _, ok := r.Current().(*tls.Config) + if !ok { + // TODO(smenon): proper error msg. + return errors.New("invalid type for refreshable") + } + return nil + }) + if err != nil { + return nil, werror.Wrap(err, "failed to build RefreshableTLSConfig") + } + return WrappedRefreshableTLSConfig{ + r: validating, + }, nil +} + // GetTLSConfig returns the most recent valid *tls.Config. // If the last refreshable update resulted in an error, that error is logged and // the previous value is returned.