From 64053df1226e66799b2447f0b460f5c9b226ad1b Mon Sep 17 00:00:00 2001 From: slisson Date: Wed, 11 Dec 2024 11:25:39 +0100 Subject: [PATCH] fix(authorization): cache remote keys RemoteJWKSet already caches keys from remote URLs, but all instances of key sources weren't reused. --- .../modelix/authorization/ModelixJWTUtil.kt | 89 +++++++++++++------ 1 file changed, 60 insertions(+), 29 deletions(-) diff --git a/authorization/src/main/kotlin/org/modelix/authorization/ModelixJWTUtil.kt b/authorization/src/main/kotlin/org/modelix/authorization/ModelixJWTUtil.kt index 5101f5e732..e5bfc752ab 100644 --- a/authorization/src/main/kotlin/org/modelix/authorization/ModelixJWTUtil.kt +++ b/authorization/src/main/kotlin/org/modelix/authorization/ModelixJWTUtil.kt @@ -27,6 +27,7 @@ import com.nimbusds.jose.util.Resource import com.nimbusds.jwt.JWTClaimsSet import com.nimbusds.jwt.JWTParser import com.nimbusds.jwt.proc.DefaultJWTProcessor +import com.nimbusds.jwt.proc.JWTProcessor import io.ktor.client.HttpClient import io.ktor.client.request.get import io.ktor.client.statement.bodyAsText @@ -46,7 +47,6 @@ import java.util.Base64 import java.util.Date import java.util.UUID import javax.crypto.spec.SecretKeySpec -import kotlin.String class ModelixJWTUtil { private var hmacKeys = LinkedHashMap() @@ -57,6 +57,45 @@ class ModelixJWTUtil { private var ktorClient: HttpClient? = null var accessControlDataProvider: IAccessControlDataProvider = EmptyAccessControlDataProvider() + private var jwtProcessor: JWTProcessor? = null + + @Synchronized + private fun getOrCreateJwtProcessor(): JWTProcessor { + return jwtProcessor ?: DefaultJWTProcessor().also { processor -> + val keySelectors: List> = hmacKeys.map { it.toPair() }.map { + SingleKeyJWSKeySelector(it.first, SecretKeySpec(it.second, it.first.name)) + } + jwksUrls.map { + val client = this.ktorClient + if (client == null) { + JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL(it) + } else { + JWSAlgorithmFamilyJWSKeySelector.fromJWKSource(RemoteJWKSet(it, KtorResourceRetriever(client))) + } + } + rsaPublicKeys.map { + JWSAlgorithmFamilyJWSKeySelector.fromJWKSource(ImmutableJWKSet(JWKSet(it.toPublicJWK()))) + } + + processor.jwsKeySelector = if (keySelectors.size == 1) keySelectors.single() else CompositeJWSKeySelector(keySelectors) + + val expectedKeyId = this.expectedKeyId + if (expectedKeyId != null) { + processor.jwsVerifierFactory = object : DefaultJWSVerifierFactory() { + override fun createJWSVerifier(header: JWSHeader, key: Key): JWSVerifier { + if (header.keyID != expectedKeyId) { + throw BadJOSEException("Invalid key ID. [expected=$expectedKeyId, actual=${header.keyID}]") + } + return super.createJWSVerifier(header, key) + } + } + } + }.also { jwtProcessor = it } + } + + private fun resetJwtProcess() { + jwtProcessor = null + } + + @Synchronized fun canVerifyTokens(): Boolean { return hmacKeys.isNotEmpty() || rsaPublicKeys.isNotEmpty() || jwksUrls.isNotEmpty() } @@ -64,21 +103,27 @@ class ModelixJWTUtil { /** * Tokens are only valid if they are signed with this key. */ + @Synchronized fun requireKeyId(id: String) { expectedKeyId = id } + @Synchronized fun useKtorClient(client: HttpClient) { + resetJwtProcess() this.ktorClient = client.config { expectSuccess = true } } + @Synchronized fun addJwksUrl(url: String) { addJwksUrl(URI(url).toURL()) } + @Synchronized fun addJwksUrl(url: URL) { + resetJwtProcess() jwksUrls += url } @@ -91,28 +136,37 @@ class ModelixJWTUtil { addHmacKey(key.toByteArray().ensureMinSecretLength(algorithm), algorithm) } + @Synchronized fun addPublicKey(key: JWK) { requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" } requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" } + resetJwtProcess() rsaPublicKeys.add(key) } + @Synchronized fun setRSAPrivateKey(key: JWK) { requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" } requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" } + resetJwtProcess() this.rsaPrivateKey = key addPublicKey(key.toPublicJWK()) } + @Synchronized private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) { + resetJwtProcess() hmacKeys[algorithm] = key } + @Synchronized fun getPublicJWKS(): JWKSet { return JWKSet(listOfNotNull(rsaPrivateKey)).toPublicJWKSet() } + @Synchronized fun loadKeysFromEnvironment() { + resetJwtProcess() System.getenv().filter { it.key.startsWith("MODELIX_JWK_FILE") }.values.forEach { File(it).walk().forEach { file -> when (file.extension) { @@ -127,6 +181,7 @@ class ModelixJWTUtil { .forEach { addJwksUrl(URI(it).toURL()) } } + @Synchronized fun createAccessToken(user: String, grantedPermissions: List, additionalTokenContent: (TokenBuilder) -> Unit = {}): String { val signer: JWSSigner val algorithm: JWSAlgorithm @@ -174,6 +229,7 @@ class ModelixJWTUtil { return token.claims[ModelixTokenConstants.PERMISSIONS]?.asList(String::class.java) } + @Synchronized fun loadGrantedPermissions(token: DecodedJWT, evaluator: PermissionEvaluator) { val permissions = extractPermissions(token) @@ -252,6 +308,7 @@ class ModelixJWTUtil { } private fun loadJwk(key: JWK) { + resetJwtProcess() if (key.isPrivate) { setRSAPrivateKey(key) } else { @@ -259,35 +316,9 @@ class ModelixJWTUtil { } } + @Synchronized fun verifyToken(token: String) { - DefaultJWTProcessor().also { processor -> - val keySelectors: List> = hmacKeys.map { it.toPair() }.map { - SingleKeyJWSKeySelector(it.first, SecretKeySpec(it.second, it.first.name)) - } + jwksUrls.map { - val client = this.ktorClient - if (client == null) { - JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL(it) - } else { - JWSAlgorithmFamilyJWSKeySelector.fromJWKSource(RemoteJWKSet(it, KtorResourceRetriever(client))) - } - } + rsaPublicKeys.map { - JWSAlgorithmFamilyJWSKeySelector.fromJWKSource(ImmutableJWKSet(JWKSet(it.toPublicJWK()))) - } - - processor.jwsKeySelector = if (keySelectors.size == 1) keySelectors.single() else CompositeJWSKeySelector(keySelectors) - - val expectedKeyId = this.expectedKeyId - if (expectedKeyId != null) { - processor.jwsVerifierFactory = object : DefaultJWSVerifierFactory() { - override fun createJWSVerifier(header: JWSHeader, key: Key): JWSVerifier { - if (header.keyID != expectedKeyId) { - throw BadJOSEException("Invalid key ID. [expected=$expectedKeyId, actual=${header.keyID}]") - } - return super.createJWSVerifier(header, key) - } - } - } - }.process(JWTParser.parse(token), null) + getOrCreateJwtProcessor().process(JWTParser.parse(token), null) } class TokenBuilder(private val builder: JWTClaimsSet.Builder) {