Skip to content

Commit

Permalink
Merge pull request #2819 from Tharsanan1/fix-org
Browse files Browse the repository at this point in the history
Fix backend jwt
  • Loading branch information
Krishanx92 authored Feb 15, 2025
2 parents f24da7e + f7b0008 commit 7f9db94
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 37 deletions.
81 changes: 44 additions & 37 deletions gateway/enforcer/internal/extproc/ext_proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ type ExternalProcessingServer struct {
apiStore *datastore.APIStore
subscriptionApplicationDatastore *datastore.SubscriptionApplicationDataStore
ratelimitHelper *ratelimit.AIRatelimitHelper
requestConfigHolder *requestconfig.Holder
cfg *config.Server
jwtTransformer *transformer.JWTTransformer
modelBasedRoundRobinTracker *datastore.ModelBasedRoundRobinTracker
Expand Down Expand Up @@ -124,7 +123,7 @@ func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APISt
}

ratelimitHelper := ratelimit.NewAIRatelimitHelper(cfg)
envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, apiStore, subAppDatastore, ratelimitHelper, nil, cfg, jwtTransformer, modelBasedRoundRobinTracker})
envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, apiStore, subAppDatastore, ratelimitHelper, cfg, jwtTransformer, modelBasedRoundRobinTracker})
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", cfg.ExternalProcessingPort))
if err != nil {
cfg.Logger.Error(err, fmt.Sprintf("Failed to listen on port: %s", cfg.ExternalProcessingPort))
Expand Down Expand Up @@ -170,13 +169,14 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
}

resp := &envoy_service_proc_v3.ProcessingResponse{}
requestConfigHolder := &requestconfig.Holder{}
// log req.Attributes
s.log.Info(fmt.Sprintf("Attributes: %+v", req.Attributes))
dynamicMetadataKeyValuePairs := make(map[string]string)
switch v := req.Request.(type) {
case *envoy_service_proc_v3.ProcessingRequest_RequestHeaders:
requestConfigHolder := &requestconfig.Holder{}
attributes, err := extractExternalProcessingXDSRouteMetadataAttributes(req.GetAttributes())
requestConfigHolder.ExternalProcessingEnvoyAttributes = attributes
if err != nil {
s.log.Error(err, "failed to extract context attributes")
resp = &envoy_service_proc_v3.ProcessingResponse{
Expand Down Expand Up @@ -213,6 +213,10 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
}
apiKey := util.PrepareAPIKey(attributes.VHost, attributes.BasePath, attributes.APIVersion)
requestConfigHolder.MatchedAPI = s.apiStore.GetMatchedAPI(util.PrepareAPIKey(attributes.VHost, attributes.BasePath, attributes.APIVersion))
// Do not remove or modify this nil check. It is necessary to avoid nil pointer dereference.
if requestConfigHolder.MatchedAPI == nil {
break
}
dynamicMetadataKeyValuePairs[customOrgMetadataKey] = requestConfigHolder.MatchedAPI.OrganizationID

dynamicMetadataKeyValuePairs[matchedAPIMetadataKey] = apiKey
Expand All @@ -224,8 +228,8 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
dynamicMetadataKeyValuePairs[analytics.APIOrganizationIDKey] = requestConfigHolder.MatchedAPI.OrganizationID
dynamicMetadataKeyValuePairs[analytics.APICreatorTenantDomainKey] = requestConfigHolder.MatchedAPI.OrganizationID

requestConfigHolder.ExternalProcessingEnvoyAttributes = attributes
if requestConfigHolder.MatchedAPI != nil && requestConfigHolder.MatchedAPI.APIDefinitionPath != "" {

if requestConfigHolder.MatchedAPI.APIDefinitionPath != "" {
definitionPath := requestConfigHolder.MatchedAPI.APIDefinitionPath
s.cfg.Logger.Info(fmt.Sprintf("definition Path: %v", definitionPath))
fullPath := requestConfigHolder.MatchedAPI.BasePath + requestConfigHolder.MatchedAPI.APIDefinitionPath
Expand Down Expand Up @@ -280,15 +284,18 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
s.cfg.Logger.Info(fmt.Sprintf("Metadata context : %+v", req.GetMetadataContext()))

requestConfigHolder.MatchedResource = httpHandler.GetMatchedResource(requestConfigHolder.MatchedAPI, *requestConfigHolder.ExternalProcessingEnvoyAttributes)
if requestConfigHolder.MatchedResource != nil {
requestConfigHolder.MatchedResource.RouteMetadataAttributes = attributes
dynamicMetadataKeyValuePairs[matchedResourceMetadataKey] = requestConfigHolder.MatchedResource.GetResourceIdentifier()
dynamicMetadataKeyValuePairs[analytics.APIResourceTemplateKey] = requestConfigHolder.MatchedResource.Path
s.log.Info(fmt.Sprintf("Matched Resource Endpoints: %+v", requestConfigHolder.MatchedResource.Endpoints))
if requestConfigHolder.MatchedResource.Endpoints != nil && len(requestConfigHolder.MatchedResource.Endpoints.URLs) > 0 {
dynamicMetadataKeyValuePairs[analytics.DestinationKey] = requestConfigHolder.MatchedResource.Endpoints.URLs[0]
}
// Do not remove or modify this nil check. It is necessary to avoid nil pointer dereference.
if requestConfigHolder.MatchedResource == nil {
break
}
requestConfigHolder.MatchedResource.RouteMetadataAttributes = attributes
dynamicMetadataKeyValuePairs[matchedResourceMetadataKey] = requestConfigHolder.MatchedResource.GetResourceIdentifier()
dynamicMetadataKeyValuePairs[analytics.APIResourceTemplateKey] = requestConfigHolder.MatchedResource.Path
s.log.Info(fmt.Sprintf("Matched Resource Endpoints: %+v", requestConfigHolder.MatchedResource.Endpoints))
if requestConfigHolder.MatchedResource.Endpoints != nil && len(requestConfigHolder.MatchedResource.Endpoints.URLs) > 0 {
dynamicMetadataKeyValuePairs[analytics.DestinationKey] = requestConfigHolder.MatchedResource.Endpoints.URLs[0]
}

metadata, err := extractExternalProcessingMetadata(req.GetMetadataContext())
if err != nil {
s.log.Error(err, "failed to extract context metadata")
Expand All @@ -301,7 +308,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
// s.log.Info(fmt.Sprintf("Matched Resource: %v", requestConfigHolder.MatchedResource))
// s.log.Info(fmt.Sprintf("req holderrr: %+v\n s: %+v", &requestConfigHolder, &s))
s.log.Info(fmt.Sprintf("req holderrr: %+v\n s: %+v", requestConfigHolder, s))
if requestConfigHolder.MatchedResource != nil && requestConfigHolder.MatchedResource.AuthenticationConfig != nil && !requestConfigHolder.MatchedResource.AuthenticationConfig.Disabled && !requestConfigHolder.MatchedAPI.DisableAuthentication {
if requestConfigHolder.MatchedResource.AuthenticationConfig != nil && !requestConfigHolder.MatchedResource.AuthenticationConfig.Disabled && !requestConfigHolder.MatchedAPI.DisableAuthentication {
jwtValidationInfo := s.jwtTransformer.TransformJWTClaims(requestConfigHolder.MatchedAPI.OrganizationID, requestConfigHolder.ExternalProcessingEnvoyMetadata)
requestConfigHolder.JWTValidationInfo = &jwtValidationInfo
s.log.Sugar().Infof("jwtValidation==%v", jwtValidationInfo)
Expand All @@ -327,7 +334,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
}
}
backendJWT := ""
if requestConfigHolder.MatchedAPI != nil && requestConfigHolder.MatchedAPI.BackendJwtConfiguration != nil && requestConfigHolder.MatchedAPI.BackendJwtConfiguration.Enabled {
if requestConfigHolder.MatchedAPI.BackendJwtConfiguration != nil && requestConfigHolder.MatchedAPI.BackendJwtConfiguration.Enabled {
backendJWT = jwtbackend.CreateBackendJWT(requestConfigHolder, s.cfg)
s.log.Sugar().Infof("generated backendJWT==%v", backendJWT)
}
Expand All @@ -348,7 +355,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
dynamicMetadataKeyValuePairs[matchedSubscriptionMetadataKey] = requestConfigHolder.MatchedSubscription.UUID
}

if requestConfigHolder.MatchedAPI != nil && requestConfigHolder.MatchedAPI.EndpointSecurity != nil {
if requestConfigHolder.MatchedAPI.EndpointSecurity != nil {
s.cfg.Logger.Info(fmt.Sprintf("Inside API Level Endpoint Security: %+v", requestConfigHolder.MatchedAPI.EndpointSecurity))
for _, es := range requestConfigHolder.MatchedAPI.EndpointSecurity {
if es.Enabled {
Expand All @@ -374,7 +381,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
}
}

if requestConfigHolder.MatchedResource != nil && requestConfigHolder.MatchedResource.EndpointSecurity != nil {
if requestConfigHolder.MatchedResource.EndpointSecurity != nil {
s.cfg.Logger.Info(fmt.Sprintf("Resource Level Endpoint Security: %+v", requestConfigHolder.MatchedResource.EndpointSecurity))
for _, es := range requestConfigHolder.MatchedResource.EndpointSecurity {
if es.Enabled {
Expand Down Expand Up @@ -887,7 +894,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro
s.log.Info(fmt.Sprintf("Unknown Request type %v\n", v))
}
// Set dynamic metadata
dynamicMetadata, err := buildDynamicMetadata(s.prepareMetadataKeyValuePairAndAddTo(dynamicMetadataKeyValuePairs))
dynamicMetadata, err := buildDynamicMetadata(prepareMetadataKeyValuePairAndAddTo(dynamicMetadataKeyValuePairs, requestConfigHolder, s.cfg))
if err != nil {
s.log.Error(err, "failed to build dynamic metadata")
} else {
Expand Down Expand Up @@ -1146,32 +1153,32 @@ func buildDynamicMetadata(keyValuePairs *map[string]string) (*structpb.Struct, e
return rootStruct, nil
}

func (s *ExternalProcessingServer) prepareMetadataKeyValuePairAndAddTo(metadataKeyValuePair map[string]string) *map[string]string {
if s.requestConfigHolder != nil && s.requestConfigHolder.MatchedAPI != nil {
metadataKeyValuePair[analytics.APIIDKey] = s.requestConfigHolder.MatchedAPI.UUID
metadataKeyValuePair[analytics.APIContextKey] = s.requestConfigHolder.MatchedAPI.BasePath
metadataKeyValuePair[organizationMetadataKey] = s.requestConfigHolder.MatchedAPI.OrganizationID
metadataKeyValuePair[analytics.APINameKey] = s.requestConfigHolder.MatchedAPI.Name
metadataKeyValuePair[analytics.APIVersionKey] = s.requestConfigHolder.MatchedAPI.Version
metadataKeyValuePair[analytics.APITypeKey] = s.requestConfigHolder.MatchedAPI.APIType
func prepareMetadataKeyValuePairAndAddTo(metadataKeyValuePair map[string]string, requestConfigHolder *requestconfig.Holder, cfg *config.Server) *map[string]string {
if requestConfigHolder != nil && requestConfigHolder.MatchedAPI != nil {
metadataKeyValuePair[analytics.APIIDKey] = requestConfigHolder.MatchedAPI.UUID
metadataKeyValuePair[analytics.APIContextKey] = requestConfigHolder.MatchedAPI.BasePath
metadataKeyValuePair[organizationMetadataKey] = requestConfigHolder.MatchedAPI.OrganizationID
metadataKeyValuePair[analytics.APINameKey] = requestConfigHolder.MatchedAPI.Name
metadataKeyValuePair[analytics.APIVersionKey] = requestConfigHolder.MatchedAPI.Version
metadataKeyValuePair[analytics.APITypeKey] = requestConfigHolder.MatchedAPI.APIType
// metadataKeyValuePair[analytics.ApiCreatorKey] = s.requestConfigHolder.MatchedAPI.Creator
// metadataKeyValuePair[analytics.ApiCreatorTenantDomainKey] = s.requestConfigHolder.MatchedAPI.CreatorTenant
metadataKeyValuePair[analytics.APIOrganizationIDKey] = s.requestConfigHolder.MatchedAPI.OrganizationID
metadataKeyValuePair[analytics.APIOrganizationIDKey] = requestConfigHolder.MatchedAPI.OrganizationID

metadataKeyValuePair[analytics.CorrelationIDKey] = s.requestConfigHolder.ExternalProcessingEnvoyAttributes.CorrelationID
metadataKeyValuePair[analytics.RegionKey] = s.cfg.EnforcerRegionID
metadataKeyValuePair[analytics.CorrelationIDKey] = requestConfigHolder.ExternalProcessingEnvoyAttributes.CorrelationID
metadataKeyValuePair[analytics.RegionKey] = cfg.EnforcerRegionID
// metadataKeyValuePair[analytics.UserAgentKey] = s.requestConfigHolder.Metadata.UserAgent
// metadataKeyValuePair[analytics.ClientIpKey] = s.requestConfigHolder.Metadata.ClientIP
// metadataKeyValuePair[analytics.ApiResourceTemplateKey] = s.requestConfigHolder.ApiResourceTemplate
// metadataKeyValuePair[analytics.Destination] = s.requestConfigHolder.Metadata.Destination
metadataKeyValuePair[analytics.APIEnvironmentKey] = s.requestConfigHolder.MatchedAPI.Environment

if s.requestConfigHolder.MatchedApplication != nil {
metadataKeyValuePair[analytics.AppIDKey] = s.requestConfigHolder.MatchedApplication.UUID
metadataKeyValuePair[analytics.AppUUIDKey] = s.requestConfigHolder.MatchedApplication.UUID
metadataKeyValuePair[analytics.AppKeyTypeKey] = s.requestConfigHolder.MatchedAPI.EnvType
metadataKeyValuePair[analytics.AppNameKey] = s.requestConfigHolder.MatchedApplication.Name
metadataKeyValuePair[analytics.AppOwnerKey] = s.requestConfigHolder.MatchedApplication.Owner
metadataKeyValuePair[analytics.APIEnvironmentKey] = requestConfigHolder.MatchedAPI.Environment

if requestConfigHolder.MatchedApplication != nil {
metadataKeyValuePair[analytics.AppIDKey] = requestConfigHolder.MatchedApplication.UUID
metadataKeyValuePair[analytics.AppUUIDKey] = requestConfigHolder.MatchedApplication.UUID
metadataKeyValuePair[analytics.AppKeyTypeKey] = requestConfigHolder.MatchedAPI.EnvType
metadataKeyValuePair[analytics.AppNameKey] = requestConfigHolder.MatchedApplication.Name
metadataKeyValuePair[analytics.AppOwnerKey] = requestConfigHolder.MatchedApplication.Owner
}
}
return &metadataKeyValuePair
Expand Down
12 changes: 12 additions & 0 deletions gateway/enforcer/internal/jwtbackend/jwt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
dialectURI = "http://wso2.org/claims/"
sha256WithRSA = "SHA256withRSA"
)
var restrictedClaims = []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "application", "tierInfo", "subscribedAPIs", "aut"}

// CreateBackendJWT creates a JWT token for the backend.
func CreateBackendJWT(rch *requestconfig.Holder, cfg *config.Server) string {
Expand Down Expand Up @@ -89,6 +90,17 @@ func CreateBackendJWT(rch *requestconfig.Holder, cfg *config.Server) string {
Type: "string",
}
}
for claim, claimValue := range rch.JWTValidationInfo.Claims {
if !util.Contains(restrictedClaims, claim) {
if claimValue, ok := claimValue.(string); ok {
customClaims[claim] = &dto.ClaimValue{
Value: claimValue,
Type: "string",
}
}
}
}

}
signatureAlgorithm := bjc.SignatureAlgorithm
if signatureAlgorithm != "NONE" && signatureAlgorithm != sha256WithRSA {
Expand Down
28 changes: 28 additions & 0 deletions gateway/enforcer/internal/util/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) 2025, WSO2 LLC. (http://www.wso2.org) All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package util

// Contains checks if a string is present in a slice.
func Contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}

0 comments on commit 7f9db94

Please sign in to comment.