Skip to content

Commit

Permalink
fix(authorization): reload keys when the file changes
Browse files Browse the repository at this point in the history
If the Kubernetes secret containing the RSA key changed, the container usually isn't restarted,
so we have to make sure we don't use an outdated key.
  • Loading branch information
slisson committed Dec 16, 2024
1 parent 59ca7b6 commit 2b80425
Show file tree
Hide file tree
Showing 2 changed files with 259 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ import com.nimbusds.jose.crypto.MACSigner
import com.nimbusds.jose.crypto.RSASSASigner
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.JWKMatcher
import com.nimbusds.jose.jwk.JWKSelector
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.jwk.KeyUse
import com.nimbusds.jose.jwk.RSAKey
import com.nimbusds.jose.jwk.gen.RSAKeyGenerator
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
import com.nimbusds.jose.jwk.source.JWKSource
import com.nimbusds.jose.jwk.source.RemoteJWKSet
import com.nimbusds.jose.proc.BadJOSEException
import com.nimbusds.jose.proc.JWSAlgorithmFamilyJWSKeySelector
Expand Down Expand Up @@ -47,32 +50,26 @@ import java.util.Base64
import java.util.Date
import java.util.UUID
import javax.crypto.spec.SecretKeySpec
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

class ModelixJWTUtil {
private var hmacKeys = LinkedHashMap<JWSAlgorithm, ByteArray>()
private var rsaPrivateKey: JWK? = null
private var rsaPublicKeys = ArrayList<JWK>()
private val jwksUrls = LinkedHashSet<URL>()
private val hmacKeys = LinkedHashMap<JWSAlgorithm, ByteArray>()
private val jwkSources = ArrayList<JWKSource<SecurityContext>>()
private var expectedKeyId: String? = null
private var ktorClient: HttpClient? = null
var accessControlDataProvider: IAccessControlDataProvider = EmptyAccessControlDataProvider()

private var jwtProcessor: JWTProcessor<SecurityContext>? = null
var fileRefreshTime: Duration = 5.seconds

@Synchronized
private fun getOrCreateJwtProcessor(): JWTProcessor<SecurityContext> {
return jwtProcessor ?: DefaultJWTProcessor<SecurityContext>().also { processor ->
val keySelectors: List<JWSKeySelector<SecurityContext>> = hmacKeys.map { it.toPair() }.map {
SingleKeyJWSKeySelector<SecurityContext>(it.first, SecretKeySpec(it.second, it.first.name))
} + jwksUrls.map {
val client = this.ktorClient
if (client == null) {
JWSAlgorithmFamilyJWSKeySelector.fromJWKSetURL<SecurityContext>(it)
} else {
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(RemoteJWKSet(it, KtorResourceRetriever(client)))
}
} + rsaPublicKeys.map {
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(ImmutableJWKSet(JWKSet(it.toPublicJWK())))
} + jwkSources.map {
JWSAlgorithmFamilyJWSKeySelector.fromJWKSource<SecurityContext>(it)
}

processor.jwsKeySelector = if (keySelectors.size == 1) keySelectors.single() else CompositeJWSKeySelector(keySelectors)
Expand All @@ -91,13 +88,19 @@ class ModelixJWTUtil {
}.also { jwtProcessor = it }
}

private fun resetJwtProcess() {
fun getPrivateKey(): JWK? {
return jwkSources.flatMap {
it.get(JWKSelector(JWKMatcher.Builder().privateOnly(true).algorithms(JWSAlgorithm.Family.RSA.toSet()).build()), null)
}.firstOrNull()
}

private fun resetJwtProcessor() {
jwtProcessor = null
}

@Synchronized
fun canVerifyTokens(): Boolean {
return hmacKeys.isNotEmpty() || rsaPublicKeys.isNotEmpty() || jwksUrls.isNotEmpty()
return hmacKeys.isNotEmpty() || jwkSources.isNotEmpty()
}

/**
Expand All @@ -110,7 +113,7 @@ class ModelixJWTUtil {

@Synchronized
fun useKtorClient(client: HttpClient) {
resetJwtProcess()
resetJwtProcessor()
this.ktorClient = client.config {
expectSuccess = true
}
Expand All @@ -123,8 +126,8 @@ class ModelixJWTUtil {

@Synchronized
fun addJwksUrl(url: URL) {
resetJwtProcess()
jwksUrls += url
resetJwtProcessor()
jwkSources.add(RemoteJWKSet(url, ktorClient?.let { KtorResourceRetriever(it) }))
}

fun setHmac512Key(key: String) {
Expand All @@ -140,56 +143,67 @@ class ModelixJWTUtil {
fun addPublicKey(key: JWK) {
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
resetJwtProcess()
rsaPublicKeys.add(key)
resetJwtProcessor()
jwkSources.add(ImmutableJWKSet(JWKSet(key.toPublicJWK())))
}

@Synchronized
fun setRSAPrivateKey(key: JWK) {
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
resetJwtProcess()
this.rsaPrivateKey = key
addPublicKey(key.toPublicJWK())
resetJwtProcessor()
jwkSources.add(ImmutableJWKSet(JWKSet(listOf(key, key.toPublicJWK()))))
}

@Synchronized
private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) {
resetJwtProcess()
hmacKeys[algorithm] = key
fun addJWK(key: JWK) {
requireNotNull(key.keyID) { "Key doesn't specify a key ID: $key" }
requireNotNull(key.algorithm) { "Key doesn't specify an algorithm: $key" }
resetJwtProcessor()
if (key.isPrivate) {
jwkSources.add(ImmutableJWKSet(JWKSet(listOf(key, key.toPublicJWK()))))
} else {
jwkSources.add(ImmutableJWKSet(JWKSet(key)))
}
}

@Synchronized
fun getPublicJWKS(): JWKSet {
return JWKSet(listOfNotNull(rsaPrivateKey)).toPublicJWKSet()
private fun addHmacKey(key: ByteArray, algorithm: JWSAlgorithm) {
resetJwtProcessor()
hmacKeys[algorithm] = key
}

@Synchronized
fun loadKeysFromEnvironment() {
resetJwtProcess()
resetJwtProcessor()
System.getenv().filter { it.key.startsWith("MODELIX_JWK_FILE") }.values.forEach {
File(it).walk().forEach { file ->
when (file.extension) {
"pem" -> loadPemFile(file.readText())
"json" -> loadJwkFile(file.readText())
}
}
loadKeysFromFiles(File(it))
}

// allows multiple URLs (MODELIX_JWK_URI1, MODELIX_JWK_URI2, MODELIX_JWK_URI_MODEL_SERVER, ...)
System.getenv().filter { it.key.startsWith("MODELIX_JWK_URI") }.values
.forEach { addJwksUrl(URI(it).toURL()) }
}

fun loadKeysFromFiles(fileOrFolder: File) {
fileOrFolder.walk().forEach { file ->
when (file.extension) {
"pem" -> jwkSources.add(PemFileJWKSet(file))
"json" -> jwkSources.add(FileJWKSet(file))
}
}
}

@Synchronized
fun createAccessToken(user: String, grantedPermissions: List<String>, additionalTokenContent: (TokenBuilder) -> Unit = {}): String {
val signer: JWSSigner
val algorithm: JWSAlgorithm
val signingKeyId: String?
val jwk = this.rsaPrivateKey
val jwk = getPrivateKey()
if (jwk != null) {
signer = RSASSASigner(jwk.toRSAKey().toRSAPrivateKey())
algorithm = checkNotNull(jwk.algorithm) { "RSA key doesn't specify an algorithm" } as JWSAlgorithm
algorithm = checkNotNull(jwk.algorithm) { "RSA key doesn't specify an algorithm" }
.let { it as? JWSAlgorithm ?: JWSAlgorithm.parse(it.name) }
signingKeyId = checkNotNull(jwk.keyID) { "RSA key doesn't specify a key ID" }
} else {
val entry = checkNotNull(hmacKeys.entries.firstOrNull()) { "No keys for signing provided" }
Expand Down Expand Up @@ -273,11 +287,7 @@ class ModelixJWTUtil {
.issueTime(Date())
.algorithm(JWSAlgorithm.RS256)
.generate()
.also { setRSAPrivateKey(it) }
}

fun loadPemFile(fileContent: String): JWK {
return ensureValidKey(JWK.parseFromPEMEncodedObjects(fileContent)).also { loadJwk(it) }
.also { addJWK(it) }
}

private fun ensureValidKey(key: JWK): JWK {
Expand All @@ -302,19 +312,6 @@ class ModelixJWTUtil {
return RSAKey.Builder(rsaKey).keyID(keyId).build()
}

fun loadJwkFile(fileContent: String): JWK {
return JWK.parse(fileContent).also { loadJwk(it) }
}

private fun loadJwk(key: JWK) {
resetJwtProcess()
if (key.isPrivate) {
setRSAPrivateKey(key)
} else {
addPublicKey(key)
}
}

@Synchronized
fun verifyToken(token: String) {
getOrCreateJwtProcessor().process(JWTParser.parse(token), null)
Expand All @@ -327,6 +324,30 @@ class ModelixJWTUtil {
}
}

private inner class PemFileJWKSet<C : SecurityContext>(pemFile: File) : FileJWKSet<C>(pemFile) {
override fun readFile(): JWKSet {
return JWKSet(ensureValidKey(JWK.parseFromPEMEncodedObjects(file.readText())))
}
}

private open inner class FileJWKSet<C : SecurityContext>(val file: File) : JWKSource<C> {
private var loadedAt: Long = 0
private var cached: JWKSet? = null

open fun readFile(): JWKSet {
return JWKSet(JWK.parse(file.readText()))
}

override fun get(jwkSelector: JWKSelector, context: C?): List<JWK?>? {
val jwks = cached.takeIf { System.currentTimeMillis() - loadedAt < fileRefreshTime.inWholeMilliseconds }
?: readFile().also {
cached = it
loadedAt = System.currentTimeMillis()
}
return jwkSelector.select(jwks)
}
}

companion object {
fun extractUserId(jwt: DecodedJWT): String? {
return jwt.getClaim(KeycloakTokenConstants.EMAIL)?.asString()
Expand Down
Loading

0 comments on commit 2b80425

Please sign in to comment.