diff --git a/src/Microsoft.Health.Fhir.Api.UnitTests/Features/Audit/SmartOnFhirAuditLoggingFilterAttributeTests.cs b/src/Microsoft.Health.Fhir.Api.UnitTests/Features/Audit/SmartOnFhirAuditLoggingFilterAttributeTests.cs new file mode 100644 index 0000000000..6cd4a70927 --- /dev/null +++ b/src/Microsoft.Health.Fhir.Api.UnitTests/Features/Audit/SmartOnFhirAuditLoggingFilterAttributeTests.cs @@ -0,0 +1,182 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +// ------------------------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Internal; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.Primitives; +using Microsoft.Health.Fhir.Api.Features.Audit; +using Microsoft.Health.Fhir.Core.Features.Context; +using NSubstitute; +using Xunit; + +namespace Microsoft.Health.Fhir.Api.UnitTests.Features.Audit +{ + public class SmartOnFhirAuditLoggingFilterAttributeTests + { + private const string ControllerName = "controller"; + private const string ActionName = "action"; + private const string Action = "smart-on-fhir-action"; + + private readonly IAuditLogger _auditLogger = Substitute.For(); + private readonly IFhirRequestContextAccessor _fhirRequestContextAccessor = Substitute.For(); + private readonly SmartOnFhirAuditLoggingFilterAttribute _filter; + private IReadOnlyCollection> _loggedClaims; + private readonly QueryCollection _queryCollection; + private readonly FormCollection _formCollection; + private readonly string _correlationId; + + public SmartOnFhirAuditLoggingFilterAttributeTests() + { + _correlationId = Guid.NewGuid().ToString(); + _fhirRequestContextAccessor.FhirRequestContext.CorrelationId.Returns(_correlationId); + + _filter = new SmartOnFhirAuditLoggingFilterAttribute(Action, _auditLogger, _fhirRequestContextAccessor); + _auditLogger.LogAudit(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any(), Arg.Do>>(c => _loggedClaims = c)); + + var passedValues = new Dictionary + { + { "client_id", new StringValues("1234") }, + { "secret", new StringValues("secret") }, + }; + + _queryCollection = new QueryCollection( + passedValues); + + _formCollection = new FormCollection( + passedValues); + } + + [Fact] + public void GivenAController_WhenExecutingAction_ThenAuditLogShouldBeLogged() + { + SetupExecutingAction(_queryCollection, null); + + VerifyAuditLoggerReceivedLogAudit(AuditAction.Executing, null); + VerifyClaims(); + } + + [Fact] + public void GivenAController_WhenExecutingActionWithFormCollection_ThenAuditLogShouldBeLogged() + { + SetupExecutingAction(null, _formCollection); + + VerifyAuditLoggerReceivedLogAudit(AuditAction.Executing, null); + VerifyClaims(); + } + + [Fact] + public void GivenAController_WhenExecutingActionWithEmptyQueryCollectionAndEmptyFormCollection_ThenAuditLogShouldBeLogged() + { + SetupExecutingAction(null, null); + + VerifyAuditLoggerReceivedLogAudit(AuditAction.Executing, null); + Assert.Empty(_loggedClaims); + } + + [Fact] + public void GivenAController_WhenExecutedAction_ThenAuditLogShouldBeLogged() + { + const HttpStatusCode expectedStatusCode = HttpStatusCode.InternalServerError; + SetupExecutedAction(expectedStatusCode, new OkResult(), _queryCollection, null); + + VerifyAuditLoggerReceivedLogAudit(AuditAction.Executed, expectedStatusCode); + VerifyClaims(); + } + + [Fact] + public void GivenAController_WhenExecutedActionWithFormCollection_ThenAuditLogShouldBeLogged() + { + const HttpStatusCode expectedStatusCode = HttpStatusCode.InternalServerError; + SetupExecutedAction(expectedStatusCode, new OkResult(), null, _formCollection); + + VerifyAuditLoggerReceivedLogAudit(AuditAction.Executed, expectedStatusCode); + VerifyClaims(); + } + + [Fact] + public void GivenAController_WhenExecutedActionWithEmptyQueryCollectionAndEmptyFormCollection_ThenAuditLogShouldBeLogged() + { + const HttpStatusCode expectedStatusCode = HttpStatusCode.InternalServerError; + SetupExecutedAction(expectedStatusCode, new OkResult(), null, null); + + VerifyAuditLoggerReceivedLogAudit(AuditAction.Executed, expectedStatusCode); + Assert.Empty(_loggedClaims); + } + + private void SetupExecutingAction(IQueryCollection queryCollection, IFormCollection formCollection) + { + var actionExecutingContext = new ActionExecutingContext( + new ActionContext(new DefaultHttpContext(), new RouteData(), new ControllerActionDescriptor() { DisplayName = "Executing Context Test Descriptor" }), + new List(), + new Dictionary(), + new MockController()); + + actionExecutingContext.HttpContext.Request.Query = queryCollection; + actionExecutingContext.HttpContext.Request.Form = formCollection; + + actionExecutingContext.ActionDescriptor = new ControllerActionDescriptor() + { + ControllerName = ControllerName, + ActionName = ActionName, + }; + + _filter.OnActionExecuting(actionExecutingContext); + } + + private void SetupExecutedAction(HttpStatusCode expectedStatusCode, IActionResult result, IQueryCollection queryCollection, IFormCollection formCollection) + { + var resultExecutedContext = new ResultExecutedContext( + new ActionContext(new DefaultHttpContext(), new RouteData(), new ControllerActionDescriptor() { DisplayName = "Executed Context Test Descriptor" }), + new List(), + result, + new MockController()); + + resultExecutedContext.HttpContext.Request.Query = queryCollection; + resultExecutedContext.HttpContext.Request.Form = formCollection; + + resultExecutedContext.HttpContext.Response.StatusCode = (int)expectedStatusCode; + + resultExecutedContext.ActionDescriptor = new ControllerActionDescriptor() + { + ControllerName = ControllerName, + ActionName = ActionName, + }; + + _filter.OnResultExecuted(resultExecutedContext); + } + + private void VerifyAuditLoggerReceivedLogAudit(AuditAction auditAction, HttpStatusCode? httpStatusCode) + { + _auditLogger.Received(1).LogAudit( + Arg.Is(auditAction), + Arg.Is(Action), + Arg.Is(x => x == null), + Arg.Any(), + Arg.Is(httpStatusCode), + Arg.Is(_correlationId), + Arg.Any>>()); + } + + private void VerifyClaims() + { + Assert.Equal(1, _loggedClaims.Count); + (string key, string value) = _loggedClaims.First(); + Assert.Equal("client_id", key); + Assert.Equal("1234", value); + } + + private class MockController : Controller + { + } + } +} diff --git a/src/Microsoft.Health.Fhir.Api/Controllers/AadSmartOnFhirProxyController.cs b/src/Microsoft.Health.Fhir.Api/Controllers/AadSmartOnFhirProxyController.cs index fe00047c7a..656b684f49 100644 --- a/src/Microsoft.Health.Fhir.Api/Controllers/AadSmartOnFhirProxyController.cs +++ b/src/Microsoft.Health.Fhir.Api/Controllers/AadSmartOnFhirProxyController.cs @@ -10,16 +10,18 @@ using System.Text; using System.Threading.Tasks; using EnsureThat; -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Extensions; using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.Filters; -using Microsoft.Extensions.Configuration; +using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Health.Fhir.Api.Features.ActionResults; +using Microsoft.Health.Fhir.Api.Features.Audit; using Microsoft.Health.Fhir.Api.Features.Routing; using Microsoft.Health.Fhir.Core.Configs; using Microsoft.Health.Fhir.Core.Exceptions; +using Microsoft.Health.Fhir.Core.Features.Routing; +using Microsoft.Health.Fhir.ValueSets; using Microsoft.IdentityModel.Tokens; using Newtonsoft.Json.Linq; @@ -32,12 +34,12 @@ namespace Microsoft.Health.Fhir.Api.Controllers [Route("AadSmartOnFhirProxy")] public class AadSmartOnFhirProxyController : Controller { - private readonly SecurityConfiguration _securityConfiguration; private readonly bool _isAadV2; private readonly ILogger _logger; private readonly IHttpClientFactory _httpClientFactory; private readonly string _aadAuthorizeEndpoint; private readonly string _aadTokenEndpoint; + private readonly IUrlResolver _urlResolver; // TODO: _launchContextFields contain a list of fields that we will transmit as part of launch context, should be configurable private readonly string[] _launchContextFields = { "patient", "encounter", "practitioner", "need_patient_banner", "smart_style_url" }; @@ -45,51 +47,53 @@ public class AadSmartOnFhirProxyController : Controller /// /// Initializes a new instance of the class. /// - /// Security configuration parameters. + /// Security configuration parameters. /// HTTP Client Factory. + /// The URL resolver. /// The logger. - public AadSmartOnFhirProxyController(IOptions securityConfiguration, IHttpClientFactory httpClientFactory, ILogger logger) + public AadSmartOnFhirProxyController(IOptions securityConfigurationOptions, IHttpClientFactory httpClientFactory, IUrlResolver urlResolver, ILogger logger) { - EnsureArg.IsNotNull(securityConfiguration, nameof(securityConfiguration)); + EnsureArg.IsNotNull(securityConfigurationOptions?.Value, nameof(securityConfigurationOptions)); + EnsureArg.IsNotNull(httpClientFactory, nameof(httpClientFactory)); + EnsureArg.IsNotNull(urlResolver, nameof(urlResolver)); + EnsureArg.IsNotNull(logger, nameof(logger)); - _securityConfiguration = securityConfiguration.Value; - _isAadV2 = new Uri(_securityConfiguration.Authentication.Authority).Segments.Contains("v2.0"); - _logger = logger; + SecurityConfiguration securityConfiguration = securityConfigurationOptions.Value; + _isAadV2 = new Uri(securityConfiguration.Authentication.Authority).Segments.Contains("v2.0"); _httpClientFactory = httpClientFactory; + _urlResolver = urlResolver; + _logger = logger; - var openIdConfigurationUrl = $"{_securityConfiguration.Authentication.Authority}/.well-known/openid-configuration"; + var openIdConfigurationUrl = $"{securityConfiguration.Authentication.Authority}/.well-known/openid-configuration"; HttpResponseMessage openIdConfigurationResponse; - using (var httpClient = httpClientFactory.CreateClient()) + var httpClient = httpClientFactory.CreateClient(); + + try + { + openIdConfigurationResponse = httpClient.GetAsync(new Uri(openIdConfigurationUrl)).GetAwaiter().GetResult(); + } + catch (Exception ex) { - try + if (ex is HttpRequestException || ex is OperationCanceledException) { - openIdConfigurationResponse = httpClient.GetAsync(new Uri(openIdConfigurationUrl)).GetAwaiter().GetResult(); + logger.LogWarning(ex, $"There was an exception while attempting to read the OpenId Configuration from \"{openIdConfigurationUrl}\"."); + throw new OpenIdConfigurationException(); } - catch (Exception ex) - { - if (ex is HttpRequestException || ex is OperationCanceledException) - { - logger.LogWarning(ex, $"There was an exception while attempting to read the OpenId Configuration from \"{openIdConfigurationUrl}\"."); - throw new OpenIdConfigurationException(); - } - throw; - } + throw; } openIdConfigurationResponse.EnsureSuccessStatusCode(); var openIdConfiguration = JObject.Parse(openIdConfigurationResponse.Content.ReadAsStringAsync().GetAwaiter().GetResult()); - try - { - _aadTokenEndpoint = openIdConfiguration["token_endpoint"].Value(); - _aadAuthorizeEndpoint = openIdConfiguration["authorization_endpoint"].Value(); - } - catch (ArgumentNullException ex) + _aadTokenEndpoint = openIdConfiguration["token_endpoint"]?.Value(); + _aadAuthorizeEndpoint = openIdConfiguration["authorization_endpoint"]?.Value(); + + if (_aadTokenEndpoint == null || _aadAuthorizeEndpoint == null) { - logger.LogError($"{ex.Message}, There was an error attempting to read the endpoints from \"{openIdConfigurationUrl}\"."); + logger.LogError($"There was an error attempting to read the endpoints from \"{openIdConfigurationUrl}\"."); throw new OpenIdConfigurationException(); } } @@ -104,7 +108,9 @@ public AadSmartOnFhirProxyController(IOptions securityCon /// scope URL parameter. /// state URL parameter. /// aud (audience) URL parameter. - [HttpGet("authorize")] + [HttpGet] + [TypeFilter(typeof(SmartOnFhirAuditLoggingFilterAttribute), Arguments = new object[] { AuditEventSubType.SmartOnFhirAuthorize })] + [Route("authorize", Name = RouteNames.AadSmartOnFhirProxyAuthorize)] public ActionResult Authorize( [FromQuery(Name = "response_type")] string responseType, [FromQuery(Name = "client_id")] string clientId, @@ -124,21 +130,26 @@ public ActionResult Authorize( launch = Base64UrlEncoder.Encode("{}"); } - JObject newStateObj = JObject.Parse("{}"); - newStateObj.Add("s", state); - newStateObj.Add("l", launch); + var newStateObj = new JObject + { + { "s", state }, + { "l", launch }, + }; + + string newState = Base64UrlEncoder.Encode(newStateObj.ToString()); - string newState = Base64UrlEncoder.Encode(newStateObj.ToString(Newtonsoft.Json.Formatting.None)); + Uri callbackUrl = _urlResolver.ResolveRouteNameUrl(RouteNames.AadSmartOnFhirProxyCallback, new RouteValueDictionary { { "encodedRedirect", Base64UrlEncoder.Encode(redirectUri.ToString()) } }); - Uri callbackUrl = new Uri( - Request.Scheme + "://" + Request.Host + "/AadSmartOnFhirProxy/callback/" + - Base64UrlEncoder.Encode(redirectUri.ToString())); + var queryBuilder = new QueryBuilder + { + { "response_type", responseType }, + { "redirect_uri", callbackUrl.AbsoluteUri }, + { "client_id", clientId }, + }; - StringBuilder queryStringBuilder = new StringBuilder(); - queryStringBuilder.Append($"response_type={responseType}&redirect_uri={callbackUrl.ToString()}&client_id={clientId}"); if (!_isAadV2) { - queryStringBuilder.Append($"&resource={aud}"); + queryBuilder.Add("resource", aud); } else { @@ -147,7 +158,7 @@ public ActionResult Authorize( EnsureArg.IsNotNull(scope, nameof(scope)); var scopes = scope.Split(' '); - StringBuilder scopesBuilder = new StringBuilder(); + var scopesBuilder = new StringBuilder(); string[] wellKnownScopes = { "profile", "openid", "email", "offline_access" }; foreach (var s in scopes) @@ -163,12 +174,13 @@ public ActionResult Authorize( } var newScopes = scopesBuilder.ToString().TrimEnd(' '); - queryStringBuilder.Append($"&scope={Uri.EscapeDataString(newScopes)}"); + + queryBuilder.Add("scope", Uri.EscapeDataString(newScopes)); } - queryStringBuilder.Append($"&state={newState}"); + queryBuilder.Add("state", newState); - return Redirect($"{_aadAuthorizeEndpoint}?{queryStringBuilder.ToString()}"); + return Redirect($"{_aadAuthorizeEndpoint}{queryBuilder}"); } /// @@ -180,7 +192,9 @@ public ActionResult Authorize( /// session_state URL parameter. /// error URL parameter. /// error_description URL parameter. - [HttpGet("callback/{encodedRedirect}")] + [HttpGet] + [TypeFilter(typeof(SmartOnFhirAuditLoggingFilterAttribute), Arguments = new object[] { AuditEventSubType.SmartOnFhirCallback })] + [Route("callback/{encodedRedirect}", Name = RouteNames.AadSmartOnFhirProxyCallback)] public ActionResult Callback( string encodedRedirect, [FromQuery(Name = "code")] string code, @@ -189,19 +203,24 @@ public ActionResult Callback( [FromQuery(Name = "error")] string error, [FromQuery(Name = "error_description")] string errorDescription) { - Uri redirectUrl = new Uri(Base64UrlEncoder.Decode(encodedRedirect)); + var redirectUrl = new Uri(Base64UrlEncoder.Decode(encodedRedirect)); if (!string.IsNullOrEmpty(error)) { - return Redirect($"{redirectUrl.ToString()}?error={error}&error_description={errorDescription}"); + var errorQueryBuilder = new QueryBuilder + { + { "error", error }, + { "error_description", errorDescription }, + }; + return Redirect($"{redirectUrl}{errorQueryBuilder}"); } string compoundCode; string newState; try { - JObject launchStateParameters = JObject.Parse(Base64UrlEncoder.Decode(state)); - JObject launchParameters = JObject.Parse(Base64UrlEncoder.Decode(launchStateParameters["l"].ToString())); + var launchStateParameters = JObject.Parse(Base64UrlEncoder.Decode(state)); + var launchParameters = JObject.Parse(Base64UrlEncoder.Decode(launchStateParameters["l"].ToString())); launchParameters.Add("code", code); newState = launchStateParameters["s"].ToString(); compoundCode = Base64UrlEncoder.Encode(launchParameters.ToString(Newtonsoft.Json.Formatting.None)); @@ -212,7 +231,13 @@ public ActionResult Callback( return BadRequest("Invalid launch context parameters"); } - return Redirect($"{redirectUrl.ToString()}?code={compoundCode}&state={newState}&session_state={sessionState}"); + var queryBuilder = new QueryBuilder + { + { "code", compoundCode }, + { "state", newState }, + { "session_state", sessionState }, + }; + return Redirect($"{redirectUrl}{queryBuilder}"); } /// @@ -223,7 +248,9 @@ public ActionResult Callback( /// redirect_uri request parameter. /// client_id request parameter. /// client_secret request parameter. - [HttpPost("token")] + [HttpPost] + [TypeFilter(typeof(SmartOnFhirAuditLoggingFilterAttribute), Arguments = new object[] { AuditEventSubType.SmartOnFhirToken })] + [Route("token", Name = RouteNames.AadSmartOnFhirProxyToken)] public async Task Token( [FromForm(Name = "grant_type")] string grantType, [FromForm(Name = "code")] string compoundCode, @@ -245,7 +272,7 @@ public async Task Token( // TODO: Add handling of 'aud' -> 'resource', should that be an error or should translation be done? if (grantType != "authorization_code") { - List> fields = new List>(); + var fields = new List>(); foreach (var f in Request.Form) { fields.Add(new KeyValuePair(f.Key, f.Value)); @@ -279,9 +306,7 @@ public async Task Token( return BadRequest("Invalid compound authorization code"); } - Uri callbackUrl = new Uri( - Request.Scheme + "://" + Request.Host + "/AadSmartOnFhirProxy/callback/" + - Base64UrlEncoder.Encode(redirectUri.ToString())); + Uri callbackUrl = _urlResolver.ResolveRouteNameUrl(RouteNames.AadSmartOnFhirProxyCallback, new RouteValueDictionary { { "encodedRedirect", Base64UrlEncoder.Encode(redirectUri.ToString()) } }); // TODO: Deal with client secret in basic auth header var content = new FormUrlEncodedContent( @@ -289,12 +314,12 @@ public async Task Token( { new KeyValuePair("grant_type", grantType), new KeyValuePair("code", code), - new KeyValuePair("redirect_uri", callbackUrl.ToString()), + new KeyValuePair("redirect_uri", callbackUrl.AbsoluteUri), new KeyValuePair("client_id", clientId), new KeyValuePair("client_secret", clientSecret), }); - var response = await client.PostAsync(new Uri(_aadTokenEndpoint), content); + HttpResponseMessage response = await client.PostAsync(new Uri(_aadTokenEndpoint), content); if (!response.IsSuccessStatusCode) { @@ -321,13 +346,13 @@ public async Task Token( // Replace fully qualifies scopes with short scopes and replace $ string[] scopes = tokenResponse["scope"].ToString().Split(' '); - StringBuilder scopesBuilder = new StringBuilder(); + var scopesBuilder = new StringBuilder(); foreach (var s in scopes) { if (IsAbsoluteUrl(s)) { - Uri scopeUri = new Uri(s); + var scopeUri = new Uri(s); scopesBuilder.Append($"{scopeUri.Segments.Last().Replace('$', '/')} "); } else @@ -351,4 +376,4 @@ private static bool IsAbsoluteUrl(string url) return Uri.TryCreate(url, UriKind.Absolute, out _); } } -} \ No newline at end of file +} diff --git a/src/Microsoft.Health.Fhir.Api/Features/Audit/SmartOnFhirAuditLoggingFilterAttribute.cs b/src/Microsoft.Health.Fhir.Api/Features/Audit/SmartOnFhirAuditLoggingFilterAttribute.cs new file mode 100644 index 0000000000..5618d06879 --- /dev/null +++ b/src/Microsoft.Health.Fhir.Api/Features/Audit/SmartOnFhirAuditLoggingFilterAttribute.cs @@ -0,0 +1,77 @@ +// ------------------------------------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +// ------------------------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Net; +using EnsureThat; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.Extensions.Primitives; +using Microsoft.Health.Fhir.Core.Features.Context; + +namespace Microsoft.Health.Fhir.Api.Features.Audit +{ + [AttributeUsage(AttributeTargets.Method)] + internal class SmartOnFhirAuditLoggingFilterAttribute : ActionFilterAttribute + { + private const string ClientId = "client_id"; + private readonly IAuditLogger _auditLogger; + private readonly IFhirRequestContextAccessor _fhirRequestContextAccessor; + private readonly string _action; + + public SmartOnFhirAuditLoggingFilterAttribute(string action, IAuditLogger auditLogger, IFhirRequestContextAccessor fhirRequestContextAccessor) + { + EnsureArg.IsNotNullOrWhiteSpace(action, nameof(action)); + EnsureArg.IsNotNull(auditLogger, nameof(auditLogger)); + EnsureArg.IsNotNull(fhirRequestContextAccessor, nameof(fhirRequestContextAccessor)); + + _action = action; + _auditLogger = auditLogger; + _fhirRequestContextAccessor = fhirRequestContextAccessor; + } + + public override void OnActionExecuting(ActionExecutingContext context) + { + EnsureArg.IsNotNull(context, nameof(context)); + + _auditLogger.LogAudit( + AuditAction.Executing, + _action, + null, + _fhirRequestContextAccessor.FhirRequestContext.Uri, + null, + _fhirRequestContextAccessor.FhirRequestContext.CorrelationId, + GetClientIdFromQueryStringOrForm(context)); + + base.OnActionExecuting(context); + } + + public override void OnResultExecuted(ResultExecutedContext context) + { + EnsureArg.IsNotNull(context, nameof(context)); + + _auditLogger.LogAudit( + AuditAction.Executed, + _action, + null, + _fhirRequestContextAccessor.FhirRequestContext.Uri, + (HttpStatusCode)context.HttpContext.Response.StatusCode, + _fhirRequestContextAccessor.FhirRequestContext.CorrelationId, + GetClientIdFromQueryStringOrForm(context)); + + base.OnResultExecuted(context); + } + + private static ReadOnlyCollection> GetClientIdFromQueryStringOrForm(FilterContext context) + { + StringValues clientId = context.HttpContext.Request.HasFormContentType ? context.HttpContext.Request.Form[ClientId] : context.HttpContext.Request.Query[ClientId]; + + ReadOnlyCollection> claims = clientId.Select(x => new KeyValuePair(ClientId, x)).ToList().AsReadOnly(); + return claims; + } + } +} diff --git a/src/Microsoft.Health.Fhir.Api/Features/Routing/RouteNames.cs b/src/Microsoft.Health.Fhir.Api/Features/Routing/RouteNames.cs index ceb45447a5..a23a53a4b7 100644 --- a/src/Microsoft.Health.Fhir.Api/Features/Routing/RouteNames.cs +++ b/src/Microsoft.Health.Fhir.Api/Features/Routing/RouteNames.cs @@ -28,5 +28,11 @@ internal static class RouteNames internal const string SearchAllResourcesPost = "SearchAllResourcesPost"; internal const string SearchCompartmentByResourceType = "SearchCompartmentByResourceType"; + + internal const string AadSmartOnFhirProxyAuthorize = "AadSmartOnFhirProxyAuthorize"; + + internal const string AadSmartOnFhirProxyCallback = "AadSmartOnFhirProxyCallback"; + + internal const string AadSmartOnFhirProxyToken = "AadSmartOnFhirProxyToken"; } } diff --git a/src/Microsoft.Health.Fhir.Api/Features/Routing/UrlResolver.cs b/src/Microsoft.Health.Fhir.Api/Features/Routing/UrlResolver.cs index 3b93d8b79b..16b82b7e08 100644 --- a/src/Microsoft.Health.Fhir.Api/Features/Routing/UrlResolver.cs +++ b/src/Microsoft.Health.Fhir.Api/Features/Routing/UrlResolver.cs @@ -161,5 +161,18 @@ public Uri ResolveRouteUrl(IEnumerable> unsupportedSearchP return new Uri(uriString); } + + public Uri ResolveRouteNameUrl(string routeName, IDictionary routeValues) + { + var routeValueDictionary = new RouteValueDictionary(routeValues); + + var uriString = UrlHelper.RouteUrl( + routeName, + routeValueDictionary, + Request.Scheme, + Request.Host.Value); + + return new Uri(uriString); + } } } diff --git a/src/Microsoft.Health.Fhir.Api/Features/Security/SecurityProvider.cs b/src/Microsoft.Health.Fhir.Api/Features/Security/SecurityProvider.cs index eb9e95ead1..591fdfa733 100644 --- a/src/Microsoft.Health.Fhir.Api/Features/Security/SecurityProvider.cs +++ b/src/Microsoft.Health.Fhir.Api/Features/Security/SecurityProvider.cs @@ -7,6 +7,7 @@ using EnsureThat; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Health.Fhir.Api.Features.Routing; using Microsoft.Health.Fhir.Core.Configs; using Microsoft.Health.Fhir.Core.Features.Conformance; using Microsoft.Health.Fhir.Core.Features.Routing; @@ -38,7 +39,7 @@ public void Build(ListedCapabilityStatement statement) { if (_securityConfiguration.EnableAadSmartOnFhirProxy) { - statement.AddProxyOAuthSecurityService(_urlResolver.ResolveMetadataUrl(false)); + statement.AddProxyOAuthSecurityService(_urlResolver, RouteNames.AadSmartOnFhirProxyAuthorize, RouteNames.AadSmartOnFhirProxyToken); } else { diff --git a/src/Microsoft.Health.Fhir.Core/Features/Conformance/CapabilityStatementExtensions.cs b/src/Microsoft.Health.Fhir.Core/Features/Conformance/CapabilityStatementExtensions.cs index 2aca2d4a8e..c5474d65a5 100644 --- a/src/Microsoft.Health.Fhir.Core/Features/Conformance/CapabilityStatementExtensions.cs +++ b/src/Microsoft.Health.Fhir.Core/Features/Conformance/CapabilityStatementExtensions.cs @@ -14,6 +14,7 @@ using Hl7.Fhir.Rest; using Microsoft.Extensions.Logging; using Microsoft.Health.Fhir.Core.Exceptions; +using Microsoft.Health.Fhir.Core.Features.Routing; using Microsoft.Health.Fhir.Core.Features.Security; using Newtonsoft.Json.Linq; using static Hl7.Fhir.Model.OperationOutcome; @@ -85,17 +86,19 @@ public static ListedCapabilityStatement BuildRestResourceComponent(this ListedCa return statement; } - public static ListedCapabilityStatement AddProxyOAuthSecurityService(this ListedCapabilityStatement statement, System.Uri metadataUri) + public static ListedCapabilityStatement AddProxyOAuthSecurityService(this ListedCapabilityStatement statement, IUrlResolver urlResolver, string authorizeRouteName, string tokenRouteName) { EnsureArg.IsNotNull(statement, nameof(statement)); + EnsureArg.IsNotNull(urlResolver, nameof(urlResolver)); + EnsureArg.IsNotNullOrWhiteSpace(authorizeRouteName, nameof(authorizeRouteName)); + EnsureArg.IsNotNullOrWhiteSpace(tokenRouteName, nameof(tokenRouteName)); var restComponent = statement.GetListedRestComponent(); var security = restComponent.Security ?? new CapabilityStatement.SecurityComponent(); security.Service.Add(Constants.RestfulSecurityServiceOAuth); - var baseurl = metadataUri.Scheme + "://" + metadataUri.Authority; - var tokenEndpoint = $"{baseurl}/AadSmartOnFhirProxy/token"; - var authorizationEndpoint = $"{baseurl}/AadSmartOnFhirProxy/authorize"; + var tokenEndpoint = urlResolver.ResolveRouteNameUrl(tokenRouteName, null); + var authorizationEndpoint = urlResolver.ResolveRouteNameUrl(authorizeRouteName, null); var smartExtension = new Extension() { diff --git a/src/Microsoft.Health.Fhir.Core/Features/Routing/IUrlResolver.cs b/src/Microsoft.Health.Fhir.Core/Features/Routing/IUrlResolver.cs index ddc4ff4abe..424ca3e12d 100644 --- a/src/Microsoft.Health.Fhir.Core/Features/Routing/IUrlResolver.cs +++ b/src/Microsoft.Health.Fhir.Core/Features/Routing/IUrlResolver.cs @@ -10,7 +10,7 @@ namespace Microsoft.Health.Fhir.Core.Features.Routing { /// - /// Provides functionalities to resolve URLs. + /// Provides functionality to resolve URLs. /// public interface IUrlResolver { @@ -36,5 +36,13 @@ public interface IUrlResolver /// The continuation token. /// The URL. Uri ResolveRouteUrl(IEnumerable> unsupportedSearchParams = null, string continuationToken = null); + + /// + /// Resolves the URL for the specified routeName. + /// + /// The route name to resolve. + /// Any route values to use in the route. + /// The URL. + Uri ResolveRouteNameUrl(string routeName, IDictionary routeValues); } } diff --git a/src/Microsoft.Health.Fhir.ValueSets/AuditEventSubType.cs b/src/Microsoft.Health.Fhir.ValueSets/AuditEventSubType.cs index 593e5a6aca..bb661af207 100644 --- a/src/Microsoft.Health.Fhir.ValueSets/AuditEventSubType.cs +++ b/src/Microsoft.Health.Fhir.ValueSets/AuditEventSubType.cs @@ -37,5 +37,11 @@ public static class AuditEventSubType public const string SearchSystem = "search-system"; public const string Capabilities = "capabilities"; + + public const string SmartOnFhirAuthorize = "smart-on-fhir-authorize"; + + public const string SmartOnFhirCallback = "smart-on-fhir-callback"; + + public const string SmartOnFhirToken = "smart-on-fhir-token"; } } diff --git a/src/Microsoft.Health.Fhir.Web/Startup.cs b/src/Microsoft.Health.Fhir.Web/Startup.cs index d4f4d77f92..4aa75d12db 100644 --- a/src/Microsoft.Health.Fhir.Web/Startup.cs +++ b/src/Microsoft.Health.Fhir.Web/Startup.cs @@ -7,7 +7,6 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.ApplicationInsights; namespace Microsoft.Health.Fhir.Web { diff --git a/test/Microsoft.Health.Fhir.Tests.E2E/Rest/Audit/AuditTests.cs b/test/Microsoft.Health.Fhir.Tests.E2E/Rest/Audit/AuditTests.cs index 7c902f846f..7bcd5c92f7 100644 --- a/test/Microsoft.Health.Fhir.Tests.E2E/Rest/Audit/AuditTests.cs +++ b/test/Microsoft.Health.Fhir.Tests.E2E/Rest/Audit/AuditTests.cs @@ -4,6 +4,7 @@ // ------------------------------------------------------------------------------------------------- using System; +using System.Collections.Generic; using System.Linq; using System.Net; using System.Net.Http; @@ -22,6 +23,7 @@ namespace Microsoft.Health.Fhir.Tests.E2E.Rest.Audit public class AuditTests : IClassFixture { private const string RequestIdHeaderName = "X-Request-Id"; + private const string ExpectedClaimKey = "appid"; private readonly AuditTestFixture _fixture; private readonly FhirClient _client; @@ -144,8 +146,8 @@ public async Task GivenAnExistingResource_WhenDeleted_ThenAuditLogEntriesShouldB // TODO: The resource type being logged here is incorrect. The issue is tracked by https://github.com/Microsoft/fhir-server/issues/334. Assert.Collection( _auditLogger.GetAuditEntriesByCorrelationId(correlationId), - ae => ValidateExecutingAuditEntry(ae, "delete", expectedUri, correlationId, expectedAppId), - ae => ValidateExecutedAuditEntry(ae, "delete", null, expectedUri, HttpStatusCode.NoContent, correlationId, expectedAppId)); + ae => ValidateExecutingAuditEntry(ae, "delete", expectedUri, correlationId, expectedAppId, ExpectedClaimKey), + ae => ValidateExecutedAuditEntry(ae, "delete", null, expectedUri, HttpStatusCode.NoContent, correlationId, expectedAppId, ExpectedClaimKey)); } [Fact] @@ -288,6 +290,55 @@ await ExecuteAndValidate( expectedAppId: null); } + [Fact] + public async Task GivenASmartOnFhirRequest_WhenAuthorizeIsCalled_TheAuditLogEntriesShouldBeCreated() + { + const string pathSegment = "AadSmartOnFhirProxy/authorize?client_id=1234&response_type=json&redirect_uri=httptest&aud=localhost"; + await ExecuteAndValidate( + async () => await _client.HttpClient.GetAsync(pathSegment), + "smart-on-fhir-authorize", + pathSegment, + HttpStatusCode.Redirect, + "1234", + "client_id"); + } + + [Fact] + public async Task GivenASmartOnFhirRequest_WhenCallbackIsCalled_TheAuditLogEntriesShouldBeCreated() + { + const string pathSegment = "AadSmartOnFhirProxy/callback/aHR0cHM6Ly9sb2NhbGhvc3Q=?code=1234&state=1234&session_state=1234"; + await ExecuteAndValidate( + async () => await _client.HttpClient.GetAsync(pathSegment), + "smart-on-fhir-callback", + pathSegment, + HttpStatusCode.BadRequest, + null, + null); + } + + [Fact] + public async Task GivenASmartOnFhirRequest_WhenTokenIsCalled_TheAuditLogEntriesShouldBeCreated() + { + const string pathSegment = "AadSmartOnFhirProxy/token"; + var formFields = new List> + { + new KeyValuePair("client_id", "1234"), + new KeyValuePair("grant_type", "grantType"), + new KeyValuePair("code", "code"), + new KeyValuePair("redirect_uri", "redirectUri"), + new KeyValuePair("client_secret", "client_secret"), + }; + + var content = new FormUrlEncodedContent(formFields); + await ExecuteAndValidate( + async () => await _client.HttpClient.PostAsync(pathSegment, content), + "smart-on-fhir-token", + pathSegment, + HttpStatusCode.BadRequest, + "1234", + "client_id"); + } + [Fact] public async Task GivenAResource_WhenNotAuthorized_ThenAuditLogEntriesShouldBeCreated() { @@ -320,8 +371,30 @@ private async Task ExecuteAndValidate(Func>> action, str Assert.Collection( _auditLogger.GetAuditEntriesByCorrelationId(correlationId), - ae => ValidateExecutingAuditEntry(ae, expectedAction, expectedUri, correlationId, expectedAppId), - ae => ValidateExecutedAuditEntry(ae, expectedAction, expectedResourceType, expectedUri, expectedStatusCode, correlationId, expectedAppId)); + ae => ValidateExecutingAuditEntry(ae, expectedAction, expectedUri, correlationId, expectedAppId, ExpectedClaimKey), + ae => ValidateExecutedAuditEntry(ae, expectedAction, expectedResourceType, expectedUri, expectedStatusCode, correlationId, expectedAppId, ExpectedClaimKey)); + } + + private async Task ExecuteAndValidate(Func> action, string expectedAction, string expectedPathSegment, HttpStatusCode expectedStatusCode, string expectedClaimValue, string expectedClaimKey) + { + if (!_fixture.IsUsingInProcTestServer) + { + // This test only works with the in-proc server with customized middleware pipeline + return; + } + + HttpResponseMessage response = await action(); + + string correlationId = response.Headers.GetValues(RequestIdHeaderName).FirstOrDefault(); + + Assert.NotNull(correlationId); + + var expectedUri = new Uri($"http://localhost/{expectedPathSegment}"); + + Assert.Collection( + _auditLogger.GetAuditEntriesByCorrelationId(correlationId), + ae => ValidateExecutingAuditEntry(ae, expectedAction, expectedUri, correlationId, expectedClaimValue, expectedClaimKey), + ae => ValidateExecutedAuditEntry(ae, expectedAction, null, expectedUri, expectedStatusCode, correlationId, expectedClaimValue, expectedClaimKey)); } private async Task ExecuteAndValidate(Func clientSetup, HttpStatusCode expectedStatusCode, string expectedAppId) @@ -349,20 +422,20 @@ private async Task ExecuteAndValidate(Func clientSetup, HttpSt Assert.Collection( _auditLogger.GetAuditEntriesByCorrelationId(correlationId), - ae => ValidateExecutedAuditEntry(ae, "read", ResourceType.Patient, expectedUri, expectedStatusCode, correlationId, expectedAppId)); + ae => ValidateExecutedAuditEntry(ae, "read", ResourceType.Patient, expectedUri, expectedStatusCode, correlationId, expectedAppId, ExpectedClaimKey)); } - private void ValidateExecutingAuditEntry(AuditEntry auditEntry, string expectedAction, Uri expectedUri, string expectedCorrelationId, string expectedAppId) + private void ValidateExecutingAuditEntry(AuditEntry auditEntry, string expectedAction, Uri expectedUri, string expectedCorrelationId, string expectedClaimValue, string expectedClaimKey) { - ValidateAuditEntry(auditEntry, AuditAction.Executing, expectedAction, null, expectedUri, null, expectedCorrelationId, expectedAppId); + ValidateAuditEntry(auditEntry, AuditAction.Executing, expectedAction, null, expectedUri, null, expectedCorrelationId, expectedClaimValue, expectedClaimKey); } - private void ValidateExecutedAuditEntry(AuditEntry auditEntry, string expectedAction, ResourceType? expectedResourceType, Uri expectedUri, HttpStatusCode? expectedStatusCode, string expectedCorrelationId, string expectedAppId) + private void ValidateExecutedAuditEntry(AuditEntry auditEntry, string expectedAction, ResourceType? expectedResourceType, Uri expectedUri, HttpStatusCode? expectedStatusCode, string expectedCorrelationId, string expectedClaimValue, string expectedClaimKey) { - ValidateAuditEntry(auditEntry, AuditAction.Executed, expectedAction, expectedResourceType, expectedUri, expectedStatusCode, expectedCorrelationId, expectedAppId); + ValidateAuditEntry(auditEntry, AuditAction.Executed, expectedAction, expectedResourceType, expectedUri, expectedStatusCode, expectedCorrelationId, expectedClaimValue, expectedClaimKey); } - private void ValidateAuditEntry(AuditEntry auditEntry, AuditAction expectedAuditAction, string expectedAction, ResourceType? expectedResourceType, Uri expectedUri, HttpStatusCode? expectedStatusCode, string expectedCorrelationId, string expectedAppId) + private void ValidateAuditEntry(AuditEntry auditEntry, AuditAction expectedAuditAction, string expectedAction, ResourceType? expectedResourceType, Uri expectedUri, HttpStatusCode? expectedStatusCode, string expectedCorrelationId, string expectedClaimValue, string expectedClaimKey) { Assert.NotNull(auditEntry); Assert.Equal(expectedAuditAction, auditEntry.AuditAction); @@ -372,11 +445,11 @@ private void ValidateAuditEntry(AuditEntry auditEntry, AuditAction expectedAudit Assert.Equal(expectedStatusCode, auditEntry.StatusCode); Assert.Equal(expectedCorrelationId, auditEntry.CorrelationId); - if (expectedAppId != null) + if (expectedClaimValue != null) { Assert.Equal(1, auditEntry.Claims.Count); - Assert.Equal("appid", auditEntry.Claims.Single().Key); - Assert.Equal(expectedAppId, auditEntry.Claims.Single().Value); + Assert.Equal(expectedClaimKey, auditEntry.Claims.Single().Key); + Assert.Equal(expectedClaimValue, auditEntry.Claims.Single().Value); } else {