Skip to content

Commit

Permalink
Selenium Grid: Add trigger param to set custom capabilities for match…
Browse files Browse the repository at this point in the history
…ing specific Nodes (#6536)

Signed-off-by: Viet Nguyen Duc <[email protected]>
  • Loading branch information
VietND96 authored Feb 11, 2025
1 parent 914163c commit 9d1db63
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Here is an overview of all new **experimental** features:
- **General**: Add SecretKey to AWS SecretsManager TriggerAuthentication to allow parsing JSON / Key/Value Pairs in secrets ([#5940](https://github.com/kedacore/keda/issues/5940))
- **IBMMQ Scaler**: Handling StatusNotFound in IBMMQ scaler ([#6472](https://github.com/kedacore/keda/pull/6472))
- **RabbitMQ Scaler**: Support use of the ‘vhostName’ parameter in the ‘TriggerAuthentication’ resource ([#6369](https://github.com/kedacore/keda/issues/6369))
- **Selenium Grid**: Add trigger param to set custom capabilities for matching specific Nodes ([#6536](https://github.com/kedacore/keda/issues/6536))


### Fixes
Expand Down
147 changes: 110 additions & 37 deletions pkg/scalers/selenium_grid_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type seleniumGridScalerMetadata struct {
ActivationThreshold int64 `keda:"name=activationThreshold, order=triggerMetadata, optional"`
UnsafeSsl bool `keda:"name=unsafeSsl, order=triggerMetadata, default=false"`
NodeMaxSessions int64 `keda:"name=nodeMaxSessions, order=triggerMetadata, default=1"`
Capabilities string `keda:"name=capabilities, order=triggerMetadata, optional"`

TargetValue int64
}
Expand Down Expand Up @@ -95,15 +96,38 @@ type Slot struct {
Stereotype string `json:"stereotype"`
}

type Capability struct {
BrowserName string `json:"browserName,omitempty"`
BrowserVersion string `json:"browserVersion,omitempty"`
PlatformName string `json:"platformName,omitempty"`
type Stereotypes []struct {
Slots int64 `json:"slots"`
Stereotype map[string]interface{} `json:"stereotype"`
}

type Stereotypes []struct {
Slots int64 `json:"slots"`
Stereotype Capability `json:"stereotype"`
var ExtensionCapabilitiesPrefixes = []string{"goog:", "moz:", "ms:", "se:"}
var FunctionCapabilitiesPrefixes = []string{"se:downloadsEnabled"}

// Follow pattern in https://github.com/SeleniumHQ/selenium/blob/trunk/java/src/org/openqa/selenium/grid/data/DefaultSlotMatcher.java
func filterCapabilities(capabilities map[string]interface{}) map[string]interface{} {
filteredCapabilities := make(map[string]interface{})

for key, value := range capabilities {
retain := true
for _, excludePrefix := range ExtensionCapabilitiesPrefixes {
if strings.HasPrefix(key, excludePrefix) {
retain = false
break
}
}
for _, prefix := range FunctionCapabilitiesPrefixes {
if strings.HasPrefix(key, prefix) {
retain = true
break
}
}
if retain {
filteredCapabilities[key] = value
}
}

return filteredCapabilities
}

func NewSeleniumGridScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
Expand Down Expand Up @@ -207,38 +231,61 @@ func (s *seleniumGridScaler) getSessionsQueueLength(ctx context.Context, logger
}

if res.StatusCode != http.StatusOK {
msg := fmt.Sprintf("selenium grid returned %d", res.StatusCode)
msg := fmt.Sprintf("Selenium Grid returned response status code: %d", res.StatusCode)
logger.Error(errors.New(msg), msg)
return -1, -1, errors.New(msg)
}

defer res.Body.Close()
b, err := io.ReadAll(res.Body)
if err != nil {
logger.Error(err, fmt.Sprintf("Error when reading Selenium Grid response body: %s", err))
return -1, -1, err
}
newRequestNodes, onGoingSession, err := getCountFromSeleniumResponse(b, s.metadata.BrowserName, s.metadata.BrowserVersion, s.metadata.SessionBrowserName, s.metadata.PlatformName, s.metadata.NodeMaxSessions, logger)
newRequestNodes, onGoingSession, err := getCountFromSeleniumResponse(b, s.metadata.BrowserName, s.metadata.BrowserVersion, s.metadata.SessionBrowserName, s.metadata.PlatformName, s.metadata.NodeMaxSessions, s.metadata.Capabilities, logger)
if err != nil {
logger.Error(err, fmt.Sprintf("Error when getting count from Selenium Grid response: %s", err))
return -1, -1, err
}
return newRequestNodes, onGoingSession, nil
}

func countMatchingSlotsStereotypes(stereotypes Stereotypes, browserName string, browserVersion string, sessionBrowserName string, platformName string) int64 {
func getCapability(capability map[string]interface{}, key string) string {
value, ok := capability[key]
if ok {
return value.(string)
}
return ""
}

func getBrowserName(capability map[string]interface{}) string {
return getCapability(capability, "browserName")
}

func getBrowserVersion(capability map[string]interface{}) string {
return getCapability(capability, "browserVersion")
}

func getPlatformName(capability map[string]interface{}) string {
return getCapability(capability, "platformName")
}

func countMatchingSlotsStereotypes(stereotypes Stereotypes, browserName string, browserVersion string, sessionBrowserName string, platformName string, capabilities map[string]interface{}) int64 {
var matchingSlots int64
for _, stereotype := range stereotypes {
if checkStereotypeCapabilitiesMatch(stereotype.Stereotype, browserName, browserVersion, sessionBrowserName, platformName) {
if checkStereotypeCapabilitiesMatch(stereotype.Stereotype, browserName, browserVersion, sessionBrowserName, platformName, capabilities) {
matchingSlots += stereotype.Slots
}
}
return matchingSlots
}

func countMatchingSessions(sessions Sessions, browserName string, browserVersion string, sessionBrowserName string, platformName string, logger logr.Logger) int64 {
func countMatchingSessions(sessions Sessions, browserName string, browserVersion string, sessionBrowserName string, platformName string, capabilities map[string]interface{}, logger logr.Logger) int64 {
var matchingSessions int64
for _, session := range sessions {
var capability = Capability{}
var capability map[string]interface{}
if err := json.Unmarshal([]byte(session.Slot.Stereotype), &capability); err == nil {
if checkStereotypeCapabilitiesMatch(capability, browserName, browserVersion, sessionBrowserName, platformName) {
if checkStereotypeCapabilitiesMatch(capability, browserName, browserVersion, sessionBrowserName, platformName, capabilities) {
matchingSessions++
}
} else {
Expand All @@ -248,39 +295,58 @@ func countMatchingSessions(sessions Sessions, browserName string, browserVersion
return matchingSessions
}

func extensionCapabilitiesMatch(stereotype map[string]interface{}, capabilities map[string]interface{}) bool {
capabilities = filterCapabilities(capabilities)
if len(capabilities) == 0 {
return true
}
for key, value := range capabilities {
if stereotypeValue, ok := stereotype[key]; !ok || stereotypeValue != value {
return false
}
}
return true
}

// This function checks if the request capabilities match the scaler metadata
func checkRequestCapabilitiesMatch(request Capability, browserName string, browserVersion string, _ string, platformName string) bool {
func checkRequestCapabilitiesMatch(request map[string]interface{}, browserName string, browserVersion string, _ string, platformName string, capabilities map[string]interface{}) bool {
// Check if browserName matches
browserNameMatch := (request.BrowserName == "" && browserName == "") ||
strings.EqualFold(browserName, request.BrowserName)
_browserName := getBrowserName(request)
browserNameMatch := (_browserName == "" && browserName == "") ||
strings.EqualFold(browserName, _browserName)

// Check if browserVersion matches
browserVersionMatch := (request.BrowserVersion == "" && browserVersion == "") ||
(request.BrowserVersion != "" && strings.HasPrefix(browserVersion, request.BrowserVersion))
_browserVersion := getBrowserVersion(request)
browserVersionMatch := (_browserVersion == "" && browserVersion == "") ||
(_browserVersion != "" && strings.HasPrefix(browserVersion, _browserVersion))

// Check if platformName matches
platformNameMatch := (request.PlatformName == "" || strings.EqualFold("any", request.PlatformName) || strings.EqualFold(platformName, request.PlatformName)) &&
(platformName == "" || platformName == "any" || strings.EqualFold(platformName, request.PlatformName))
_platformName := getPlatformName(request)
platformNameMatch := (_platformName == "" || strings.EqualFold("any", _platformName) || strings.EqualFold(platformName, _platformName)) &&
(platformName == "" || platformName == "any" || strings.EqualFold(platformName, _platformName))

return browserNameMatch && browserVersionMatch && platformNameMatch
return browserNameMatch && browserVersionMatch && platformNameMatch && extensionCapabilitiesMatch(request, capabilities)
}

// This function checks if Node stereotypes or ongoing sessions match the scaler metadata
func checkStereotypeCapabilitiesMatch(capability Capability, browserName string, browserVersion string, sessionBrowserName string, platformName string) bool {
func checkStereotypeCapabilitiesMatch(capability map[string]interface{}, browserName string, browserVersion string, sessionBrowserName string, platformName string, capabilities map[string]interface{}) bool {
// Check if browserName matches
browserNameMatch := (capability.BrowserName == "" && browserName == "") ||
strings.EqualFold(browserName, capability.BrowserName) ||
strings.EqualFold(sessionBrowserName, capability.BrowserName)
_browserName := getBrowserName(capability)
browserNameMatch := (_browserName == "" && browserName == "") ||
strings.EqualFold(browserName, _browserName) ||
strings.EqualFold(sessionBrowserName, _browserName)

// Check if browserVersion matches
browserVersionMatch := (capability.BrowserVersion == "" && browserVersion == "") ||
(capability.BrowserVersion != "" && strings.HasPrefix(browserVersion, capability.BrowserVersion))
_browserVersion := getBrowserVersion(capability)
browserVersionMatch := (_browserVersion == "" && browserVersion == "") ||
(_browserVersion != "" && strings.HasPrefix(browserVersion, _browserVersion))

// Check if platformName matches
platformNameMatch := (capability.PlatformName == "" || strings.EqualFold("any", capability.PlatformName) || strings.EqualFold(platformName, capability.PlatformName)) &&
(platformName == "" || platformName == "any" || strings.EqualFold(platformName, capability.PlatformName))
_platformVersion := getPlatformName(capability)
platformNameMatch := (_platformVersion == "" || strings.EqualFold("any", _platformVersion) || strings.EqualFold(platformName, _platformVersion)) &&
(platformName == "" || platformName == "any" || strings.EqualFold(platformName, _platformVersion))

return browserNameMatch && browserVersionMatch && platformNameMatch
return browserNameMatch && browserVersionMatch && platformNameMatch && extensionCapabilitiesMatch(capability, capabilities)
}

func checkNodeReservedSlots(reservedNodes []ReservedNodes, nodeID string, availableSlots int64) int64 {
Expand All @@ -304,7 +370,7 @@ func updateOrAddReservedNode(reservedNodes []ReservedNodes, nodeID string, slotC
return append(reservedNodes, ReservedNodes{ID: nodeID, SlotCount: slotCount, MaxSession: maxSession})
}

func getCountFromSeleniumResponse(b []byte, browserName string, browserVersion string, sessionBrowserName string, platformName string, nodeMaxSessions int64, logger logr.Logger) (int64, int64, error) {
func getCountFromSeleniumResponse(b []byte, browserName string, browserVersion string, sessionBrowserName string, platformName string, nodeMaxSessions int64, _capabilities string, logger logr.Logger) (int64, int64, error) {
// Track number of available slots of existing Nodes in the Grid can be reserved for the matched requests
var availableSlots int64
// Track number of matched requests in the sessions queue will be served by this scaler
Expand All @@ -314,6 +380,13 @@ func getCountFromSeleniumResponse(b []byte, browserName string, browserVersion s
if err := json.Unmarshal(b, &seleniumResponse); err != nil {
return 0, 0, err
}
capabilities := map[string]interface{}{}
if _capabilities != "" {
if err := json.Unmarshal([]byte(_capabilities), &capabilities); err != nil {
logger.Error(err, fmt.Sprintf("Error when unmarshaling trigger metadata 'capabilities': %s", err))
return 0, 0, err
}
}

var sessionQueueRequests = seleniumResponse.Data.SessionsInfo.SessionQueueRequests
var nodes = seleniumResponse.Data.NodesInfo.Nodes
Expand All @@ -324,9 +397,9 @@ func getCountFromSeleniumResponse(b []byte, browserName string, browserVersion s
var onGoingSessions int64
for requestIndex, sessionQueueRequest := range sessionQueueRequests {
var isRequestMatched bool
var requestCapability = Capability{}
var requestCapability map[string]interface{}
if err := json.Unmarshal([]byte(sessionQueueRequest), &requestCapability); err == nil {
if checkRequestCapabilitiesMatch(requestCapability, browserName, browserVersion, sessionBrowserName, platformName) {
if checkRequestCapabilitiesMatch(requestCapability, browserName, browserVersion, sessionBrowserName, platformName, capabilities) {
queueSlots++
isRequestMatched = true
}
Expand All @@ -343,15 +416,15 @@ func getCountFromSeleniumResponse(b []byte, browserName string, browserVersion s
var availableSlotsMatch int64
if err := json.Unmarshal([]byte(node.Stereotypes), &stereotypes); err == nil {
// Count available slots that match the request capability and scaler metadata
availableSlotsMatch += countMatchingSlotsStereotypes(stereotypes, browserName, browserVersion, sessionBrowserName, platformName)
availableSlotsMatch += countMatchingSlotsStereotypes(stereotypes, browserName, browserVersion, sessionBrowserName, platformName, capabilities)
} else {
logger.Error(err, fmt.Sprintf("Error when unmarshaling node stereotypes: %s", err))
}
if availableSlotsMatch == 0 {
continue
}
// Count ongoing sessions that match the request capability and scaler metadata
var currentSessionsMatch = countMatchingSessions(node.Sessions, browserName, browserVersion, sessionBrowserName, platformName, logger)
var currentSessionsMatch = countMatchingSessions(node.Sessions, browserName, browserVersion, sessionBrowserName, platformName, capabilities, logger)
// Count remaining available slots can be reserved for this request
var availableSlotsCanBeReserved = checkNodeReservedSlots(reservedNodes, node.ID, node.MaxSession-node.SessionCount)
// Reserve one available slot for the request if available slots match is greater than current sessions match
Expand Down Expand Up @@ -381,7 +454,7 @@ func getCountFromSeleniumResponse(b []byte, browserName string, browserVersion s

// Count ongoing sessions across all nodes that match the scaler metadata
for _, node := range nodes {
onGoingSessions += countMatchingSessions(node.Sessions, browserName, browserVersion, sessionBrowserName, platformName, logger)
onGoingSessions += countMatchingSessions(node.Sessions, browserName, browserVersion, sessionBrowserName, platformName, capabilities, logger)
}

return int64(len(newRequestNodes)), onGoingSessions, nil
Expand Down
Loading

0 comments on commit 9d1db63

Please sign in to comment.