diff --git a/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java b/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java index 0c7b914a448..2d1380bb69f 100644 --- a/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java +++ b/security/jwt/src/main/java/io/helidon/security/jwt/Jwt.java @@ -515,7 +515,6 @@ public static void addAudienceValidator(Collection> validators, S * @param validators collection of validators * @param audience audience expected to be in the token * @param mandatory whether the audience field is mandatory in the token - * @param scope jwt scope */ public static void addAudienceClaimValidator(Collection validators, Set audience, boolean mandatory) { validators.add(new AudienceValidator(audience, mandatory)); @@ -1014,8 +1013,8 @@ public JsonObject payloadJson() { /** * Validate this JWT against provided validators. - * - * This method does not work properly upon validation of the crit JWT header. + *

+ * This method does not work properly upon validation of the {@code crit} JWT header. * * @param validators Validators to validate with. Obtain them through (e.g.) {@link #defaultTimeValidators()} * , {@link #addAudienceValidator(Collection, String, boolean)} @@ -1023,7 +1022,7 @@ public JsonObject payloadJson() { * @return errors instance to check if valid and access error messages * @deprecated use {{@link #validateClaims(List)}} method instead */ - @Deprecated(since = "4.0.11") + @Deprecated(since = "4.1.0", forRemoval = true) public Errors validate(List> validators) { Errors.Collector collector = Errors.collector(); validators.forEach(it -> it.validate(this, collector)); @@ -2251,29 +2250,29 @@ public void validate(Jwt jwt, Errors.Collector collector, List v private record AudienceValidator(Set expectedAudience, boolean mandatory) implements ClaimValidator { - @Override - public JwtScope jwtScope() { - return JwtScope.PAYLOAD; - } + @Override + public JwtScope jwtScope() { + return JwtScope.PAYLOAD; + } - @Override - public Set claims() { - return Set.of(AUDIENCE); - } + @Override + public Set claims() { + return Set.of(AUDIENCE); + } - @Override - public void validate(Jwt jwt, Errors.Collector collector, List validators) { - Optional> jwtAudiences = jwt.audience(); - if (jwtAudiences.isPresent()) { - if (expectedAudience.stream().anyMatch(jwtAudiences.get()::contains)) { - return; - } - collector.fatal(jwt, "Audience must contain " + expectedAudience + ", yet it is: " + jwtAudiences); - } else { - if (mandatory) { - collector.fatal(jwt, "Audience is expected to be: " + expectedAudience + ", yet no audience in JWT"); - } + @Override + public void validate(Jwt jwt, Errors.Collector collector, List validators) { + Optional> jwtAudiences = jwt.audience(); + if (jwtAudiences.isPresent()) { + if (expectedAudience.stream().anyMatch(jwtAudiences.get()::contains)) { + return; + } + collector.fatal(jwt, "Audience must contain " + expectedAudience + ", yet it is: " + jwtAudiences); + } else { + if (mandatory) { + collector.fatal(jwt, "Audience is expected to be: " + expectedAudience + ", yet no audience in JWT"); } } } + } } diff --git a/security/jwt/src/main/java/io/helidon/security/jwt/JwtValidator.java b/security/jwt/src/main/java/io/helidon/security/jwt/JwtValidator.java new file mode 100644 index 00000000000..be41eb81c4c --- /dev/null +++ b/security/jwt/src/main/java/io/helidon/security/jwt/JwtValidator.java @@ -0,0 +1,732 @@ +package io.helidon.security.jwt; + +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.time.temporal.TemporalUnit; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.helidon.common.Errors; + +import jakarta.json.JsonString; + +public class JwtValidator { + + private final List claimValidators; + + private JwtValidator(Builder builder) { + claimValidators = List.copyOf(builder.claimValidators); + } + + public static Builder builder() { + return new Builder(); + } + + public Errors validate(Jwt jwt) { + Errors.Collector collector = Errors.collector(); + claimValidators.forEach(claimValidator -> claimValidator.validate(jwt, collector, claimValidators)); + return collector.collect(); + } + + public static final class Builder implements io.helidon.common.Builder { + + private final List claimValidators = new ArrayList<>(); + + private Builder() { + } + + @Override + public JwtValidator build() { + return new JwtValidator(this); + } + + public Builder addClaimValidator(ClaimValidator claimValidator) { + claimValidators.add(claimValidator); + return this; + } + + public Builder addHeaderValidator(String claim, Validator validator) { + + } + + public Builder addPayloadValidator(String claim, Validator validator) { + + } + + public Builder addDefaultTimeValidators() { + claimValidators.add(new ExpirationValidator(false)); + claimValidators.add(new IssueTimeValidator()); + claimValidators.add(new NotBeforeValidator()); + return this; + } + + public Builder addDefaultTimeValidators(Instant now, Duration allowedTimeSkew, boolean mandatory) { + claimValidators.add(new ExpirationValidator(now, allowedTimeSkew, mandatory)); + claimValidators.add(new IssueTimeValidator(now, allowedTimeSkew, mandatory)); + claimValidators.add(new NotBeforeValidator(now, allowedTimeSkew, mandatory)); + return this; + } + + public Builder addExpirationValidator() { + claimValidators.add(new ExpirationValidator.Builder().build()); + return this; + } + + public Builder addExpirationValidator(Consumer builderConsumer) { + ExpirationValidator.Builder builder = new ExpirationValidator.Builder(); + builderConsumer.accept(builder); + claimValidators.add(builder.build()); + return this; + } + + public Builder addExpirationValidator(boolean mandatory) { + claimValidators.add(new ExpirationValidator(mandatory)); + return this; + } + + public Builder addExpirationValidator(Instant now, Duration allowedTimeSkew, boolean mandatory) { + claimValidators.add(new ExpirationValidator(now, allowedTimeSkew, mandatory)); + return this; + } + + public Builder addNotBeforeValidator() { + claimValidators.add(new NotBeforeValidator()); + return this; + } + + public Builder addNotBeforeValidator(Instant now, Duration allowedTimeSkew, boolean mandatory) { + claimValidators.add(new NotBeforeValidator(now, allowedTimeSkew, mandatory)); + return this; + } + + public Builder addIssueTimeValidator() { + claimValidators.add(new IssueTimeValidator()); + return this; + } + + public Builder addIssueTimeValidator(Instant now, Duration allowedTimeSkew, boolean mandatory) { + claimValidators.add(new IssueTimeValidator(now, allowedTimeSkew, mandatory)); + return this; + } + + public Builder addAudienceValidator(String expectedAudience, boolean mandatory) { + claimValidators.add(new AudienceValidator(Set.of(expectedAudience), mandatory)); + return this; + } + + public Builder addAudienceValidator(Set expectedAudience, boolean mandatory) { + claimValidators.add(new AudienceValidator(Set.copyOf(expectedAudience), mandatory)); + return this; + } + + public Builder addCriticalValidator() { + claimValidators.add(new CriticalValidator()); + return this; + } + + public Builder clearValidators() { + claimValidators.clear(); + return this; + } + + } + + private abstract static class OptionalValidator { + private final boolean mandatory; + + OptionalValidator() { + this.mandatory = false; + } + + OptionalValidator(boolean mandatory) { + this.mandatory = mandatory; + } + + Optional validate(String name, Optional optional, Errors.Collector collector) { + if (mandatory && optional.isEmpty()) { + collector.fatal("Field " + name + " is mandatory, yet not defined in JWT"); + } + return optional; + } + } + + private abstract static class InstantValidator extends OptionalValidator { + private final Instant now; + private final Duration allowedTimeSkew; + + private InstantValidator(BaseBuilder builder) { + super(builder.mandatory); + this.now = builder.now; + this.allowedTimeSkew = builder.allowedTimeSkew; + } + + Instant latest() { + return instant().plus(allowedTimeSkew); + } + + Instant earliest() { + return instant().minus(allowedTimeSkew); + } + + Instant instant() { + return now == null ? Instant.now() : now; + } + + private static abstract class BaseBuilder, T> + implements io.helidon.common.Builder, T> { + + private boolean mandatory = false; + private Instant now = null; + private Duration allowedTimeSkew = Duration.ofSeconds(5); + + private BaseBuilder() { + } + + public B mandatory(boolean mandatory) { + this.mandatory = mandatory; + return me(); + } + + public B now(Instant now) { + this.now = now; + return me(); + } + + public B allowedTimeSkew(Duration allowedTimeSkew) { + this.allowedTimeSkew = allowedTimeSkew; + return me(); + } + + @SuppressWarnings("unchecked") + protected B me() { + return (B) this; + } + + } + } + + + /** + * Validator of a string field obtained from a + */ + public static final class FieldValidator extends OptionalValidator implements ClaimValidator { + private final Function> fieldAccessor; + private final JwtScope jwtScope; + private final String expectedValue; + private final String fieldName; + + private FieldValidator(Function> fieldAccessor, + String fieldName, + String expectedValue, + boolean mandatory, + JwtScope jwtScope) { + super(mandatory); + this.fieldAccessor = fieldAccessor; + this.fieldName = fieldName; + this.expectedValue = expectedValue; + this.jwtScope = jwtScope; + } + + /** + * A generic optional field validator based on a function to get the field. + * + * @param fieldAccessor function to extract field from JWT + * @param name descriptive name of the field + * @param expectedValue value to expect + * @return validator instance + * @deprecated create field validator for jwt payload. + * Please use {{@link #create(Function, String, String, JwtScope)}} + */ + @Deprecated(since = "4.0.11") + public static FieldValidator create(Function> fieldAccessor, + String name, + String expectedValue) { + return create(fieldAccessor, name, expectedValue, false, JwtScope.PAYLOAD); + } + + /** + * A generic optional field validator based on a function to get the field. + * + * @param fieldAccessor function to extract field from JWT + * @param name descriptive name of the field + * @param expectedValue value to expect + * @param jwtScope jwt scope + * @return validator instance + */ + public static FieldValidator create(Function> fieldAccessor, + String name, + String expectedValue, + JwtScope jwtScope) { + return create(fieldAccessor, name, expectedValue, false, jwtScope); + } + + /** + * A generic field validator based on a function to get the field. + * + * @param fieldAccessor function to extract field from JWT + * @param name descriptive name of the field + * @param expectedValue value to expect + * @param mandatory true for mandatory, false for optional + * @return validator instance + * @deprecated create field validator for jwt payload. + * Please use {{@link #create(Function, String, String, boolean, JwtScope)}} + */ + @Deprecated(since = "4.0.11") + public static FieldValidator create(Function> fieldAccessor, + String name, + String expectedValue, + boolean mandatory) { + return new FieldValidator(fieldAccessor, name, expectedValue, mandatory, JwtScope.PAYLOAD); + } + + + /** + * A generic field validator based on a function to get the field. + * + * @param fieldAccessor function to extract field from JWT + * @param name descriptive name of the field + * @param expectedValue value to expect + * @param mandatory true for mandatory, false for optional + * @param jwtScope jwt scope + * @return validator instance + */ + public static FieldValidator create(Function> fieldAccessor, + String name, + String expectedValue, + boolean mandatory, + JwtScope jwtScope) { + return new FieldValidator(fieldAccessor, name, expectedValue, mandatory, jwtScope); + } + + /** + * An optional header field validator. + * + * @param fieldKey name of the header claim + * @param name descriptive name of the field + * @param expectedValue value to expect + * @return validator instance + */ + public static FieldValidator createForHeader(String fieldKey, + String name, + String expectedValue) { + + return createForHeader(fieldKey, name, expectedValue, false); + } + + /** + * A header field validator. + * + * @param fieldKey name of the header claim + * @param name descriptive name of the field + * @param expectedValue value to expect + * @param mandatory whether the field is mandatory or optional + * @return validator instance + */ + public static FieldValidator createForHeader(String fieldKey, + String name, + String expectedValue, + boolean mandatory) { + + return create(jwt -> jwt.headerClaim(fieldKey).map(it -> ((JsonString) it).getString()), + name, + expectedValue, + mandatory, + JwtScope.HEADER); + } + + /** + * An optional payload field validator. + * + * @param fieldKey name of the payload claim + * @param name descriptive name of the field + * @param expectedValue value to expect + * @return validator instance + */ + public static FieldValidator createForPayload(String fieldKey, + String name, + String expectedValue) { + + return createForPayload(fieldKey, name, expectedValue, false); + } + + /** + * A payload field validator. + * + * @param fieldKey name of the payload claim + * @param name descriptive name of the field + * @param expectedValue value to expect + * @param mandatory whether the field is mandatory or optional + * @return validator instance + */ + public static FieldValidator createForPayload(String fieldKey, + String name, + String expectedValue, + boolean mandatory) { + return create(jwt -> jwt.payloadClaim(fieldKey).map(it -> ((JsonString) it).getString()), + name, + expectedValue, + false, + JwtScope.PAYLOAD); + } + + @Override + public JwtScope jwtScope() { + return jwtScope; + } + + @Override + public Set claims() { + return Set.of(fieldName); + } + + @Override + public void validate(Jwt token, Errors.Collector collector, List validators) { + super.validate(fieldName, fieldAccessor.apply(token), collector) + .ifPresent(it -> { + if (!expectedValue.equals(it)) { + collector.fatal(token, + "Expected value of field \"" + fieldName + "\" was \"" + expectedValue + "\", but " + + "actual value is: \"" + it); + } + }); + } + } + + /** + * Validator of issue time claim. + */ + public static final class IssueTimeValidator extends InstantValidator implements ClaimValidator { + + private IssueTimeValidator() { + } + + private IssueTimeValidator(Instant now, Duration allowedTimeSkew, boolean mandatory) { + super(now, allowedTimeSkew, mandatory); + } + + + /** + * New instance with default values (allowed time skew 5 seconds, optional). + * + * @return issue time validator with defaults + */ + public static IssueTimeValidator create() { + return new IssueTimeValidator(); + } + + /** + * New instance with explicit values. + * + * @param now time to validate against (to be able to validate past tokens) + * @param allowedTimeSkew allowed time skew amount (such as 5) + * @param allowedTimeSkewUnit allowed time skew unit (such as {@link ChronoUnit#SECONDS} + * @param mandatory true for mandatory, false for optional + * @return configured issue time validator + */ + public static IssueTimeValidator create(Instant now, + int allowedTimeSkew, + TemporalUnit allowedTimeSkewUnit, + boolean mandatory) { + return new IssueTimeValidator(now, allowedTimeSkew, allowedTimeSkewUnit, mandatory); + } + + @Override + public JwtScope jwtScope() { + return JwtScope.PAYLOAD; + } + + @Override + public Set claims() { + return Set.of(Jwt.ISSUED_AT); + } + + @Override + public void validate(Jwt token, Errors.Collector collector, List validators) { + Optional issueTime = token.issueTime(); + issueTime.ifPresent(it -> { + // must be issued in the past + if (latest().isBefore(it)) { + collector.fatal(token, "Token was not issued in the past: " + it); + } + }); + // ensure we fail if mandatory and not present + super.validate("issueTime", issueTime, collector); + } + } + + /** + * Validator of expiration claim. + */ + public static final class ExpirationValidator extends InstantValidator implements ClaimValidator { +// private ExpirationValidator(boolean mandatory) { +// super(mandatory); +// } +// +// private ExpirationValidator(Instant now, Duration allowedTimeSkew, boolean mandatory) { +// super(now, allowedTimeSkew, mandatory); +// } +// +// /** +// * New instance with default values (allowed time skew 5 seconds, optional). +// * +// * @return expiration time validator with defaults +// */ +// public static ExpirationValidator create() { +// return new ExpirationValidator(false); +// } +// +// /** +// * New instance with default values (allowed time skew 5 seconds). +// * +// * @param mandatory if this value is mandatory or not +// * @return expiration time validator with defaults +// */ +// public static ExpirationValidator create(boolean mandatory) { +// return new ExpirationValidator(mandatory); +// } +// +// /** +// * New instance with explicit values. +// * +// * @param now time to validate against (to be able to validate past tokens) +// * @param allowedTimeSkew allowed time skew amount (such as 5) +// * @param allowedTimeSkewUnit allowed time skew unit (such as {@link ChronoUnit#SECONDS} +// * @param mandatory true for mandatory, false for optional +// * @return expiration time validator +// */ +// public static ExpirationValidator create(Instant now, +// int allowedTimeSkew, +// TemporalUnit allowedTimeSkewUnit, +// boolean mandatory) { +// return new ExpirationValidator(now, allowedTimeSkew, allowedTimeSkewUnit, mandatory); +// } + + private ExpirationValidator(Builder builder) { + super(builder); + } + + @Override + public JwtScope jwtScope() { + return JwtScope.PAYLOAD; + } + + @Override + public Set claims() { + return Set.of(Jwt.EXPIRATION, Jwt.ISSUED_AT); + } + + @Override + public void validate(Jwt token, Errors.Collector collector, List validators) { + Optional expirationTime = token.expirationTime(); + expirationTime.ifPresent(it -> { + if (earliest().isAfter(it)) { + collector.fatal(token, "Token no longer valid, expiration: " + it); + } + token.issueTime().ifPresent(issued -> { + if (issued.isAfter(it)) { + collector.fatal(token, "Token issue date is after its expiration, " + + "issue: " + it + ", expiration: " + it); + } + }); + }); + // ensure we fail if mandatory and not present + super.validate("expirationTime", expirationTime, collector); + } + + public static final class Builder extends BaseBuilder { + @Override + public ExpirationValidator build() { + return new ExpirationValidator(this); + } + } + } + + /** + * Validator of not before claim. + */ + public static final class NotBeforeValidator extends InstantValidator implements ClaimValidator { + private NotBeforeValidator() { + } + + private NotBeforeValidator(Instant now, Duration allowedTimeSkew, boolean mandatory) { + super(now, allowedTimeSkew, mandatory); + } + + /** + * New instance with default values (allowed time skew 5 seconds, optional). + * + * @return not before time validator with defaults + */ + public static NotBeforeValidator create() { + return new NotBeforeValidator(); + } + + /** + * New instance with explicit values. + * + * @param now time to validate against (to be able to validate past tokens) + * @param allowedTimeSkew allowed time skew amount (such as 5) + * @param allowedTimeSkewUnit allowed time skew unit (such as {@link ChronoUnit#SECONDS} + * @param mandatory true for mandatory, false for optional + * @return not before time validator + */ + public static NotBeforeValidator create(Instant now, + int allowedTimeSkew, + TemporalUnit allowedTimeSkewUnit, + boolean mandatory) { + return new NotBeforeValidator(now, allowedTimeSkew, allowedTimeSkewUnit, mandatory); + } + + @Override + public JwtScope jwtScope() { + return JwtScope.PAYLOAD; + } + + @Override + public Set claims() { + return Set.of(Jwt.NOT_BEFORE); + } + + @Override + public void validate(Jwt token, Errors.Collector collector, List validators) { + Optional notBefore = token.notBefore(); + notBefore.ifPresent(it -> { + if (latest().isBefore(it)) { + collector.fatal(token, "Token not yet valid, not before: " + it); + } + }); + // ensure we fail if mandatory and not present + super.validate("notBefore", notBefore, collector); + } + } + + private static final class UserPrincipalValidator extends OptionalValidator implements ClaimValidator { + + private UserPrincipalValidator() { + super(true); + } + + @Override + public JwtScope jwtScope() { + return JwtScope.PAYLOAD; + } + + @Override + public Set claims() { + return Set.of(Jwt.USER_PRINCIPAL_NAME); + } + + @Override + public void validate(Jwt object, Errors.Collector collector, List validators) { + super.validate("User Principal", object.userPrincipal(), collector); + } + + } + + private static final class CriticalValidator implements ClaimValidator { + + @Override + public JwtScope jwtScope() { + return JwtScope.HEADER; + } + + @Override + public Set claims() { + return Set.of(Jwt.CRITICAL); + } + + @Override + public void validate(Jwt jwt, Errors.Collector collector, List validators) { + Optional> maybeCritical = jwt.headers().critical(); + if (maybeCritical.isPresent()) { + List critical = maybeCritical.get(); + if (critical.isEmpty()) { + collector.fatal(jwt, "JWT critical header must not be empty"); + return; + } + Set headerClaims = jwt.headers().headerClaims().keySet(); + boolean containsAllCritical = headerClaims.containsAll(critical); + if (!containsAllCritical) { + collector.fatal(jwt, "JWT must contain " + critical + ", yet it contains: " + headerClaims); + } + Set supportedHeaderClaims = validators + .stream() + .filter(claimValidator -> claimValidator.jwtScope() == JwtScope.HEADER) + .map(ClaimValidator::claims) + .flatMap(Set::stream) + .collect(Collectors.toSet()); + containsAllCritical = supportedHeaderClaims.containsAll(critical); + if (!containsAllCritical) { + collector.fatal(jwt, "JWT is required to process " + critical + + ", yet it process only " + supportedHeaderClaims); + } + } + } + } + + private record MaxTokenAgeValidator(Duration expectedMaxTokenAge, Duration clockSkew, boolean iatRequired) + implements ClaimValidator { + + @Override + public JwtScope jwtScope() { + return JwtScope.PAYLOAD; + } + + @Override + public Set claims() { + return Set.of(Jwt.ISSUED_AT); + } + + @Override + public void validate(Jwt jwt, Errors.Collector collector, List validators) { + Optional maybeIssueTime = jwt.issueTime(); + if (maybeIssueTime.isPresent()) { + Instant now = Instant.now(); + Instant issueTime = maybeIssueTime.get().minus(clockSkew); + Instant maxValidTime = issueTime.plus(expectedMaxTokenAge).plus(clockSkew); + if (issueTime.isBefore(now) && maxValidTime.isAfter(now)) { + return; + } + collector.fatal(jwt, "Current time need to be between " + issueTime + + " and " + maxValidTime + ", but was " + now); + } else if (iatRequired) { + collector.fatal(jwt, "Claim iat is required to be present in JWT when validating token max allowed age."); + } + } + } + + private record AudienceValidator(Set expectedAudience, boolean mandatory) implements ClaimValidator { + + @Override + public JwtScope jwtScope() { + return JwtScope.PAYLOAD; + } + + @Override + public Set claims() { + return Set.of(Jwt.AUDIENCE); + } + + @Override + public void validate(Jwt jwt, Errors.Collector collector, List validators) { + Optional> jwtAudiences = jwt.audience(); + if (jwtAudiences.isPresent()) { + if (expectedAudience.stream().anyMatch(jwtAudiences.get()::contains)) { + return; + } + collector.fatal(jwt, "Audience must contain " + expectedAudience + ", yet it is: " + jwtAudiences); + } else { + if (mandatory) { + collector.fatal(jwt, "Audience is expected to be: " + expectedAudience + ", yet no audience in JWT"); + } + } + } + } + +}