diff --git a/pkg/internal/httprequest/httprequest.go b/pkg/internal/httprequest/httprequest.go index bd72b9ffc..c2a098ec9 100644 --- a/pkg/internal/httprequest/httprequest.go +++ b/pkg/internal/httprequest/httprequest.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "net/http" + "slices" "time" "github.com/trustbloc/wallet-sdk/pkg/api" @@ -42,11 +43,28 @@ func New(httpClient httpClient, metricsLogger api.MetricsLogger) *Request { func (r *Request) Do(method, endpointURL, contentType string, body io.Reader, event, parentEvent string, errorResponseHandler func(statusCode int, responseBody []byte) error, ) ([]byte, error) { - req, err := http.NewRequestWithContext(context.Background(), method, endpointURL, body) + return r.DoContext(context.Background(), method, endpointURL, contentType, + nil, body, event, parentEvent, nil, errorResponseHandler) +} + +var defaultAcceptableStatuses = []int{http.StatusOK} + +// DoContext is the same as Do, but also accept context and headers. +func (r *Request) DoContext(ctx context.Context, method, endpointURL, contentType string, + additionalHeaders http.Header, body io.Reader, event, parentEvent string, acceptableStatuses []int, + errorResponseHandler func(statusCode int, responseBody []byte) error, +) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, method, endpointURL, body) if err != nil { return nil, err } + for header, values := range additionalHeaders { + for _, value := range values { + req.Header.Add(header, value) + } + } + if contentType != "" { req.Header.Add("Content-Type", contentType) } @@ -79,9 +97,14 @@ func (r *Request) Do(method, endpointURL, contentType string, body io.Reader, return nil, err } - if resp.StatusCode != http.StatusOK { + statuses := acceptableStatuses + if statuses == nil { + statuses = defaultAcceptableStatuses + } + + if !slices.Contains(statuses, resp.StatusCode) { if errorResponseHandler == nil { - errorResponseHandler = genericErrorResponseHandler + errorResponseHandler = genericErrorResponseHandler(statuses) } return nil, errorResponseHandler(resp.StatusCode, respBytes) @@ -106,8 +129,16 @@ func (r *Request) DoAndParse(method, endpointURL, contentType string, body io.Re return json.Unmarshal(respBytes, response) } -func genericErrorResponseHandler(statusCode int, respBytes []byte) error { - return fmt.Errorf( - "expected status code %d but got status code %d with response body %s instead", - http.StatusOK, statusCode, respBytes) +func genericErrorResponseHandler(expectedStatusCodes []int) func(statusCode int, respBytes []byte) error { + return func(statusCode int, respBytes []byte) error { + if len(expectedStatusCodes) == 1 { + return fmt.Errorf( + "expected status code %d but got status code %d with response body %s instead", + expectedStatusCodes[0], statusCode, respBytes) + } + + return fmt.Errorf( + "expected status codes %v but got status code %d with response body %s instead", + expectedStatusCodes, statusCode, respBytes) + } } diff --git a/pkg/oauth2/clientregistration.go b/pkg/oauth2/clientregistration.go index 755b62144..2bfae8a30 100644 --- a/pkg/oauth2/clientregistration.go +++ b/pkg/oauth2/clientregistration.go @@ -10,11 +10,18 @@ package oauth2 import ( "bytes" + "context" "encoding/json" "errors" "fmt" - "io" "net/http" + + "github.com/trustbloc/wallet-sdk/pkg/internal/httprequest" +) + +const ( + newRegisterClientEventText = "Register client" + fetchRequestObjectEventText = "Fetch request object via an HTTP GET request to %s" ) // RegisterClient registers a new client at the given registration endpoint. @@ -55,39 +62,15 @@ func RegisterClient(registrationEndpoint string, clientMetadata *ClientMetadata, } func getRawResponse(requestBytes []byte, registrationEndpoint string, opts *opts) ([]byte, error) { - httpReq, err := http.NewRequest( //nolint: noctx // Timeout expected to be set in HTTP client already - http.MethodPost, registrationEndpoint, bytes.NewReader(requestBytes)) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Content-Type", "application/json") - + headers := http.Header{} if opts.initialAccessBearerToken != "" { - httpReq.Header.Set("Authorization", "Bearer "+opts.initialAccessBearerToken) - } - - resp, err := opts.httpClient.Do(httpReq) - if err != nil { - return nil, err + headers.Set("Authorization", "Bearer "+opts.initialAccessBearerToken) } - defer func() { - errClose := resp.Body.Close() - if errClose != nil { - println(fmt.Sprintf("failed to close response body: %s", errClose.Error())) - } - }() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusCreated { - return nil, fmt.Errorf("server returned status code %d with body [%s]", resp.StatusCode, - string(respBody)) - } + metricsEvent := fmt.Sprintf(fetchRequestObjectEventText, registrationEndpoint) - return respBody, nil + return httprequest.New(opts.httpClient, opts.metricsLogger).DoContext(context.TODO(), + http.MethodPost, registrationEndpoint, "application/json", headers, + bytes.NewReader(requestBytes), metricsEvent, newRegisterClientEventText, + []int{http.StatusCreated}, nil) } diff --git a/pkg/oauth2/clientregistration_test.go b/pkg/oauth2/clientregistration_test.go index ca9e645a8..bf8d91e71 100644 --- a/pkg/oauth2/clientregistration_test.go +++ b/pkg/oauth2/clientregistration_test.go @@ -86,7 +86,7 @@ func TestRegisterClient(t *testing.T) { defer server.Close() response, err := oauth2.RegisterClient(server.URL, nil) - require.EqualError(t, err, "server returned status code 500 with body []") + require.ErrorContains(t, err, "expected status code 201 but got status code 500 with response body instead") require.Nil(t, response) }) t.Run("Server returns empty body, resulting in a JSON unmarshal failure", func(t *testing.T) { diff --git a/pkg/oauth2/opts.go b/pkg/oauth2/opts.go index fe3827a58..e96776d75 100644 --- a/pkg/oauth2/opts.go +++ b/pkg/oauth2/opts.go @@ -4,11 +4,13 @@ import ( "net/http" "github.com/trustbloc/wallet-sdk/pkg/api" + "github.com/trustbloc/wallet-sdk/pkg/metricslogger/noop" ) type opts struct { initialAccessBearerToken string httpClient *http.Client + metricsLogger api.MetricsLogger } // An Opt is a single option for a call to RegisterClient. @@ -29,6 +31,15 @@ func WithHTTPClient(httpClient *http.Client) Opt { } } +// WithMetricsLogger is an option for a call to RegisterClient that allows a caller to specify their MetricsLogger. +// If used, then performance metrics events will be pushed to the given MetricsLogger implementation. +// If this option is not used, then metrics logging will be disabled. +func WithMetricsLogger(metricsLogger api.MetricsLogger) Opt { + return func(opts *opts) { + opts.metricsLogger = metricsLogger + } +} + func processOpts(options []Opt) *opts { opts := mergeOpts(options) @@ -48,5 +59,9 @@ func mergeOpts(options []Opt) *opts { } } + if resolveOpts.metricsLogger == nil { + resolveOpts.metricsLogger = noop.NewMetricsLogger() + } + return resolveOpts } diff --git a/pkg/openid4ci/interaction.go b/pkg/openid4ci/interaction.go index 2065bd518..f34e77294 100644 --- a/pkg/openid4ci/interaction.go +++ b/pkg/openid4ci/interaction.go @@ -15,7 +15,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" "strings" @@ -33,6 +32,7 @@ import ( "github.com/trustbloc/wallet-sdk/pkg/common" diderrors "github.com/trustbloc/wallet-sdk/pkg/did" "github.com/trustbloc/wallet-sdk/pkg/did/wellknown" + "github.com/trustbloc/wallet-sdk/pkg/internal/httprequest" metadatafetcher "github.com/trustbloc/wallet-sdk/pkg/internal/issuermetadata" "github.com/trustbloc/wallet-sdk/pkg/models/issuer" "github.com/trustbloc/wallet-sdk/pkg/walleterror" @@ -380,19 +380,21 @@ func (i *interaction) getCredentialResponse(signer api.JWTSigner, nonce any, oAuthHTTPClient := createOAuthHTTPClient(i.oAuth2Config, i.authToken, i.httpClient) for index := range credentialTypes { - request, err := i.createCredentialRequestWithoutAccessToken(proofJWT, credentialFormats[index], + requestBody, err := i.createCredentialRequestBody(proofJWT, credentialFormats[index], credentialTypes[index], credentialContexts[index]) if err != nil { return nil, err } - // The access token header will be injected automatically by the OAuth HTTP client, so there's no need to - // explicitly set it on the request object generated by the method call above. - fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, index+1, len(credentialTypes), i.issuerMetadata.CredentialEndpoint) - responseBytes, err := i.getRawCredentialResponse(request, fetchCredentialResponseEventText, oAuthHTTPClient) + // The access token header will be injected automatically by the OAuth HTTP client, so there's no need to + // explicitly set it on the request object generated by the method call above. + responseBytes, err := httprequest.New(oAuthHTTPClient, i.metricsLogger).DoContext(context.TODO(), + http.MethodPost, i.issuerMetadata.CredentialEndpoint, "application/json", nil, + bytes.NewReader(requestBody), fetchCredentialResponseEventText, requestCredentialEventText, + []int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse) if err != nil { return nil, err } @@ -461,12 +463,9 @@ func createOAuthHTTPClient( return oAuthHTTPClient } -// The returned *http.Request will not have the access token set on it. The caller must ensure that it's set -// before sending the request to the server. -func (i *interaction) createCredentialRequestWithoutAccessToken(proofJWT, credentialFormat string, +func (i *interaction) createCredentialRequestBody(proofJWT, credentialFormat string, credentialTypes, credentialContext []string, -) (*http.Request, error) { - +) ([]byte, error) { var credentialContextToSend *[]string if len(credentialContext) > 0 { @@ -485,57 +484,7 @@ func (i *interaction) createCredentialRequestWithoutAccessToken(proofJWT, creden }, } - credentialReqBytes, err := json.Marshal(credentialReq) - if err != nil { - return nil, err - } - - request, err := http.NewRequest(http.MethodPost, //nolint: noctx - i.issuerMetadata.CredentialEndpoint, bytes.NewReader(credentialReqBytes)) - if err != nil { - return nil, err - } - - request.Header.Add("Content-Type", "application/json") - - return request, nil -} - -func (i *interaction) getRawCredentialResponse(credentialReq *http.Request, eventText string, httpClient *http.Client, -) ([]byte, error) { - timeStartHTTPRequest := time.Now() - - response, err := httpClient.Do(credentialReq) - if err != nil { - return nil, err - } - - err = i.metricsLogger.Log(&api.MetricsEvent{ - Event: eventText, - ParentEvent: requestCredentialEventText, - Duration: time.Since(timeStartHTTPRequest), - }) - if err != nil { - return nil, err - } - - responseBytes, err := io.ReadAll(response.Body) - if err != nil { - return nil, err - } - - if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusCreated { - return nil, processCredentialErrorResponse(response.StatusCode, responseBytes) - } - - defer func() { - errClose := response.Body.Close() - if errClose != nil { - println(fmt.Sprintf("failed to close response body: %s", errClose.Error())) - } - }() - - return responseBytes, nil + return json.Marshal(credentialReq) } func (i *interaction) getVCsFromCredentialResponses( diff --git a/pkg/openid4ci/issuerinitiatedinteraction.go b/pkg/openid4ci/issuerinitiatedinteraction.go index 6ab7e1428..fd9945853 100644 --- a/pkg/openid4ci/issuerinitiatedinteraction.go +++ b/pkg/openid4ci/issuerinitiatedinteraction.go @@ -443,19 +443,22 @@ func (i *IssuerInitiatedInteraction) getCredentialResponse( credentialResponses := make([]CredentialResponse, len(i.credentialTypes)) for index := range i.credentialTypes { - request, err := i.interaction.createCredentialRequestWithoutAccessToken(proofJWT, i.credentialFormats[index], + requestBody, err := i.interaction.createCredentialRequestBody(proofJWT, i.credentialFormats[index], i.credentialTypes[index], i.credentialContexts[index]) if err != nil { return nil, err } - request.Header.Add("Authorization", "Bearer "+tokenResponse.AccessToken) + headers := http.Header{} + headers.Add("Authorization", "Bearer "+tokenResponse.AccessToken) fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, index+1, len(i.credentialTypes), i.interaction.issuerMetadata.CredentialEndpoint) - responseBytes, err := i.interaction.getRawCredentialResponse(request, fetchCredentialResponseEventText, - i.interaction.httpClient) + responseBytes, err := httprequest.New(i.interaction.httpClient, i.interaction.metricsLogger).DoContext(context.TODO(), + http.MethodPost, i.interaction.issuerMetadata.CredentialEndpoint, "application/json", headers, + bytes.NewReader(requestBody), fetchCredentialResponseEventText, requestCredentialEventText, + []int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse) if err != nil { return nil, err } @@ -505,22 +508,16 @@ func (i *IssuerInitiatedInteraction) getCredentialResponsesBatch( return nil, err } - request, err := http.NewRequestWithContext(context.Background(), - http.MethodPost, - i.interaction.issuerMetadata.BatchCredentialEndpoint, - bytes.NewReader(b), - ) - if err != nil { - return nil, err - } - - request.Header.Add("Content-Type", "application/json") - request.Header.Add("Authorization", "Bearer "+tokenResponse.AccessToken) + headers := http.Header{} + headers.Add("Authorization", "Bearer "+tokenResponse.AccessToken) fetchCredentialResponseEventText := fmt.Sprintf(fetchCredentialViaGETReqEventText, numberOfCredentials, numberOfCredentials, i.interaction.issuerMetadata.BatchCredentialEndpoint) - b, err = i.interaction.getRawCredentialResponse(request, fetchCredentialResponseEventText, i.interaction.httpClient) + b, err = httprequest.New(i.interaction.httpClient, i.interaction.metricsLogger).DoContext(context.TODO(), + http.MethodPost, i.interaction.issuerMetadata.BatchCredentialEndpoint, "application/json", headers, + bytes.NewReader(b), fetchCredentialResponseEventText, requestCredentialEventText, + []int{http.StatusOK, http.StatusCreated}, processCredentialErrorResponse) if err != nil { return nil, err } diff --git a/pkg/openid4vp/acknowledgment.go b/pkg/openid4vp/acknowledgment.go index 061374ece..a5feaa7a6 100644 --- a/pkg/openid4vp/acknowledgment.go +++ b/pkg/openid4vp/acknowledgment.go @@ -8,12 +8,14 @@ package openid4vp import ( "bytes" - "context" "encoding/base64" "encoding/json" "fmt" "net/http" "net/url" + + "github.com/trustbloc/wallet-sdk/pkg/internal/httprequest" + "github.com/trustbloc/wallet-sdk/pkg/metricslogger/noop" ) const ( @@ -50,26 +52,11 @@ func (a *Acknowledgment) AcknowledgeVerifier(error, desc string, httpClient http v.Add("interaction_details", base64.StdEncoding.EncodeToString(interactionDetailsBytes)) } - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, a.ResponseURI, - bytes.NewBufferString(v.Encode())) + _, err := httprequest.New(httpClient, noop.NewMetricsLogger()).Do(http.MethodPost, a.ResponseURI, + "application/x-www-form-urlencoded", bytes.NewBufferString(v.Encode()), "", "", nil) if err != nil { return err } - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - resp, err := httpClient.Do(req) - if err != nil { - return err - } - - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - return nil } diff --git a/pkg/openid4vp/openid4vp_test.go b/pkg/openid4vp/openid4vp_test.go index 3d83eaf5e..b31a3a038 100644 --- a/pkg/openid4vp/openid4vp_test.go +++ b/pkg/openid4vp/openid4vp_test.go @@ -1130,7 +1130,7 @@ func TestOpenID4VP_PresentedClaims(t *testing.T) { "spouse":{}, "degree":{ "degree":{}, - "type":{} + "type":{} } } `, string(claimsJSON)) @@ -1224,7 +1224,7 @@ func TestAcknowledgment_AcknowledgeVerifier(t *testing.T) { }, ) - require.ErrorContains(t, err, "unexpected status code") + require.ErrorContains(t, err, "but got status code 500") }) }