2nd Update
Dropped the ibm-watsonx-ai SDK in favor of the requests library;
added some tests.
iamnotcj committed Jan 25, 2025
1 parent 6c008b8 commit ecaa821
Showing 7 changed files with 210 additions and 73 deletions.
163 changes: 93 additions & 70 deletions garak/generators/watsonx.py
@@ -2,20 +2,20 @@
from garak.generators.base import Generator
from typing import List, Union
import os
-import importlib
+import requests


class WatsonXGenerator(Generator):
    """
    This is a generator for watsonx.ai.
-    Make sure that you initialize the environment variables:
-    'WATSONX_TOKEN',
-    'WATSONX_URL',
-    and 'WATSONX_PROJECTID'.
-    To use a tuned model that is deployed, use 'deployment/deployment' for the -n flag and make sure
-    to also initialize the 'WATSONX_DEPLOYID' environment variable.
+    Make sure that you initialize the environment variables:
+    'WATSONX_TOKEN',
+    'WATSONX_URL',
+    and either 'WATSONX_PROJECTID' or 'WATSONX_DEPLOYID'.
+    To use a model that is in the "project" stage, initialize the WATSONX_PROJECTID variable with the Project ID of the model.
+    To use a tuned model that is deployed, initialize the WATSONX_DEPLOYID variable with the Deployment ID of the model.
    """

    ENV_VAR = "WATSONX_TOKEN"
@@ -24,93 +24,110 @@ class WatsonXGenerator(Generator):
    DID_ENV_VAR = "WATSONX_DEPLOYID"
    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        "uri": None,
-        "project_id": None,
-        "deployment_id": None,
-        "frequency_penalty": 0.5,
-        "logprobs": True,
-        "top_logprobs": 3,
-        "presence_penalty": 0.3,
-        "temperature": 0.7,
-        "max_tokens": 100,
-        "time_limit": 300000,
-        "top_p": 0.9,
-        "n": 1,
+        "project_id": "",
+        "deployment_id": "",
+        "prompt_variable": "input",
+        "bearer_token": "",
+        "max_tokens": 900,
    }

    generator_family_name = "watsonx"

+    def __init__(self, name="", config_root=_config):
+        super().__init__(name, config_root=config_root)
+
+        # Initialize and validate api_key
+        if self.api_key is not None:
+            os.environ[self.ENV_VAR] = self.api_key
+
+    def _set_bearer_token(self, iam_url="https://iam.cloud.ibm.com/identity/token"):
+        header = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Accept": "application/json",
+        }
+        body = (
+            "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + self.api_key
+        )
+        response = requests.post(url=iam_url, headers=header, data=body)
+        self.bearer_token = "Bearer " + response.json()["access_token"]

+    def _generate_with_project(self, payload):
+        # Generation via Project ID.
+
+        url = self.uri + "/ml/v1/text/generation?version=2023-05-29"
+
+        body = {
+            "input": payload,
+            "parameters": {
+                "decoding_method": "greedy",
+                "max_new_tokens": self.max_tokens,
+                "min_new_tokens": 0,
+                "repetition_penalty": 1,
+            },
+            "model_id": self.name,
+            "project_id": self.project_id,
+        }
+
+        headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+            "Authorization": self.bearer_token,
+        }
+
+        response = requests.post(url=url, headers=headers, json=body)
+        return response.json()
+
+    def _generate_with_deployment(self, payload):
+        # Generation via Deployment ID.
+        url = (
+            self.uri
+            + "/ml/v1/deployments/"
+            + self.deployment_id
+            + "/text/generation?version=2021-05-01"
+        )
+        body = {"parameters": {"prompt_variables": {self.prompt_variable: payload}}}
+        headers = {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+            "Authorization": self.bearer_token,
+        }
+        response = requests.post(url=url, headers=headers, json=body)
+        return response.json()

    def _validate_env_var(self):
        # Initialize and validate url.
        if self.uri is not None:
            pass
-        else :
+        else:
            self.uri = os.getenv("WATSONX_URL", None)
            if self.uri is None:
                raise ValueError(
                    f"The {self.URI_ENV_VAR} environment variable is required. Please enter the URL corresponding to the region of your provisioned service instance. \n"
                )

        # Initialize and validate project_id.
-        if self.project_id is not None:
+        if self.project_id:
            pass
-        else :
-            self.project_id = os.getenv("WATSONX_PROJECTID", None)
-            if self.project_id is None:
-                raise ValueError(
-                    f"The {self.PID_ENV_VAR} environment variable is required. Please enter the corresponding Project ID of the resource. \n"
-                )
-
-        # Import Foundation Models from ibm_watsonx_ai module. Import the Credentials function from the same module.
-        self.watsonx = importlib.import_module("ibm_watsonx_ai.foundation_models")
-        self.Credentials = getattr(
-            importlib.import_module("ibm_watsonx_ai"), "Credentials"
-        )
-
-    def get_model(self):
-        # Call Credentials function with the url and api_key.
-        credentials = self.Credentials(url=self.uri, api_key=self.api_key)
-        if self.name == "deployment/deployment":
-            self.deployment_id = os.getenv("WATSONX_DEPLOYID", None)
-            if self.deployment_id is None:
-                raise ValueError(
-                    f"The {self.DID_ENV_VAR} environment variable is required. Please enter the corresponding Deployment ID of the resource. \n"
-                )
-
-            return self.watsonx.ModelInference(
-                deployment_id=self.deployment_id,
-                credentials=credentials,
-                project_id=self.project_id,
-            )
-        else:
-            return self.watsonx.ModelInference(
-                model_id=self.name,
-                credentials=credentials,
-                project_id=self.project_id,
-                params=self.watsonx.schema.TextChatParameters(
-                    frequency_penalty=self.frequency_penalty,
-                    logprobs=self.logprobs,
-                    top_logprobs=self.top_logprobs,
-                    presence_penalty=self.presence_penalty,
-                    temperature=self.temperature,
-                    max_tokens=self.max_tokens,
-                    time_limit=self.time_limit,
-                    top_p=self.top_p,
-                    n=self.n,
-                ),
-            )
+        else:
+            self.project_id = os.getenv("WATSONX_PROJECTID", "")
+
+        # Initialize and validate deployment_id.
+        if self.deployment_id:
+            pass
+        else:
+            self.deployment_id = os.getenv("WATSONX_DEPLOYID", "")
+
+        # Check to ensure at least ONE of project_id or deployment_id is populated.
+        if not self.project_id and not self.deployment_id:
+            raise ValueError(
+                f"Either {self.PID_ENV_VAR} or {self.DID_ENV_VAR} is required. Please supply either a Project ID or Deployment ID. \n"
+            )
        return super()._validate_env_var()

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:

-        # Get/Create Model
-        model = self.get_model()
+        if not self.bearer_token:
+            self._set_bearer_token()

        # Check if message is empty. If it is, append null byte.
        if not prompt:
@@ -119,8 +136,14 @@ def _call_model(
"WARNING: Empty prompt was found. Null byte character appended to prevent API failure."
)

output = ""
if self.deployment_id:
output = self._generate_with_deployment(prompt)
else:
output = self._generate_with_project(prompt)

# Parse the output to only contain the output message from the model. Return a list containing that message.
return ["".join(model.generate(prompt=prompt)["results"][0]["generated_text"])]
return ["".join(output["results"][0]["generated_text"])]


DEFAULT_CLASS = "WatsonXGenerator"
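
For illustration, the rewritten generator boils down to two requests calls: exchange the watsonx API key for an IAM bearer token, then POST to the text-generation endpoint. The standalone sketch below reproduces that flow outside of garak. The endpoint paths, headers, and payload shape are taken from the code above, but the base URL, model, and credential values are placeholders; the us-south URL is only an example region.

import requests

API_KEY = "YOUR_WATSONX_TOKEN"                   # placeholder credential
BASE_URL = "https://us-south.ml.cloud.ibm.com"   # example regional endpoint
PROJECT_ID = "YOUR_PROJECT_ID"                   # placeholder project

# Step 1: trade the API key for an IAM bearer token (mirrors _set_bearer_token).
iam = requests.post(
    "https://iam.cloud.ibm.com/identity/token",
    headers={
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    },
    data="grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + API_KEY,
)
bearer_token = "Bearer " + iam.json()["access_token"]

# Step 2: call the generation endpoint (mirrors _generate_with_project).
response = requests.post(
    BASE_URL + "/ml/v1/text/generation?version=2023-05-29",
    headers={
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": bearer_token,
    },
    json={
        "input": "What is this?",
        "parameters": {
            "decoding_method": "greedy",
            "max_new_tokens": 900,
            "min_new_tokens": 0,
            "repetition_penalty": 1,
        },
        "model_id": "ibm/granite-3-8b-instruct",
        "project_id": PROJECT_ID,
    },
)
print(response.json()["results"][0]["generated_text"])
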
1 change: 0 additions & 1 deletion pyproject.toml
@@ -75,7 +75,6 @@ dependencies = [
"zalgolib>=0.2.2",
"ecoji>=0.1.1",
"deepl==1.17.0",
"ibm-watsonx-ai==1.1.25",
"fschat>=0.2.36",
"litellm>=1.41.21",
"jsonpath-ng>=1.6.1",
1 change: 0 additions & 1 deletion requirements.txt
@@ -28,7 +28,6 @@ deepl==1.17.0
fschat>=0.2.36
litellm>=1.41.21
jsonpath-ng>=1.6.1
-ibm-watsonx-ai==1.1.25
huggingface_hub>=0.21.0
python-magic-bin>=0.4.14; sys_platform == "win32"
python-magic>=0.4.21; sys_platform != "win32"
6 changes: 6 additions & 0 deletions tests/generators/conftest.py
@@ -18,3 +18,9 @@ def hf_endpoint_mocks():
"""Mock responses for Huggingface InferenceAPI based endpoints"""
with open(pathlib.Path(__file__).parents[0] / "hf_inference.json") as mock_openai:
return json.load(mock_openai)

@pytest.fixture
def watsonx_compat_mocks():
"""Mock responses for watsonx.ai based endpoints"""
with open(pathlib.Path(__file__).parents[0] / "watsonx.json") as mock_watsonx:
return json.load(mock_watsonx)
2 changes: 1 addition & 1 deletion tests/generators/test_generators.py
@@ -184,6 +184,7 @@ def test_generator_structure(classname):
        if classname
        not in [
            "generators.azure.AzureOpenAIGenerator",  # requires additional env variables tested in own test class
+            "generators.watsonx.WatsonXGenerator",  # requires additional env variables tested in own test class
            "generators.function.Multiple",  # requires mock local function not implemented here
            "generators.function.Single",  # requires mock local function not implemented here
            "generators.ggml.GgmlGenerator",  # validates files on disk tested in own test class
@@ -211,7 +212,6 @@ def test_instantiate_generators(classname):
"org_id": "fake", # required for NeMo
"uri": "https://example.com", # required for rest
"provider": "fake", # required for LiteLLM
"project_id": "fake", # required for watsonx
}
}
}
81 changes: 81 additions & 0 deletions tests/generators/test_watsonx.py
@@ -0,0 +1,81 @@
from garak.generators.watsonx import WatsonXGenerator
import os
import pytest
import requests_mock


DEFAULT_DEPLOYMENT_NAME = "ibm/granite-3-8b-instruct"


@pytest.fixture
def set_fake_env(request) -> None:
    stored_env = {
        WatsonXGenerator.ENV_VAR: os.getenv(WatsonXGenerator.ENV_VAR, None),
        WatsonXGenerator.PID_ENV_VAR: os.getenv(WatsonXGenerator.PID_ENV_VAR, None),
        WatsonXGenerator.URI_ENV_VAR: os.getenv(WatsonXGenerator.URI_ENV_VAR, None),
        WatsonXGenerator.DID_ENV_VAR: os.getenv(WatsonXGenerator.DID_ENV_VAR, None),
    }

    def restore_env():
        for k, v in stored_env.items():
            if v is not None:
                os.environ[k] = v
            else:
                del os.environ[k]

    os.environ[WatsonXGenerator.ENV_VAR] = "XXXXXXXXXXXXX"
    os.environ[WatsonXGenerator.PID_ENV_VAR] = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
    os.environ[WatsonXGenerator.DID_ENV_VAR] = "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
    os.environ[WatsonXGenerator.URI_ENV_VAR] = "https://garak.example.com/"
    request.addfinalizer(restore_env)


@pytest.mark.usefixtures("set_fake_env")
def test_bearer_token(watsonx_compat_mocks):
    with requests_mock.Mocker() as m:
        mock_response = watsonx_compat_mocks["watsonx_bearer_token"]

        extended_request = "identity/token"

        m.post(
            "https://garak.example.com/" + extended_request, json=mock_response["json"]
        )

        granite_llm = WatsonXGenerator(DEFAULT_DEPLOYMENT_NAME)
        granite_llm._set_bearer_token(iam_url="https://garak.example.com/identity/token")

        assert granite_llm.bearer_token == (
            "Bearer " + mock_response["json"]["access_token"]
        )


@pytest.mark.usefixtures("set_fake_env")
def test_project(watsonx_compat_mocks):
    with requests_mock.Mocker() as m:
        mock_response = watsonx_compat_mocks["watsonx_generation"]
        extended_request = "/ml/v1/text/generation?version=2023-05-29"

        m.post(
            "https://garak.example.com/" + extended_request, json=mock_response["json"]
        )

        granite_llm = WatsonXGenerator(DEFAULT_DEPLOYMENT_NAME)
        response = granite_llm._generate_with_project("What is this?")

        assert granite_llm.name == response["model_id"]


@pytest.mark.usefixtures("set_fake_env")
def test_deployment(watsonx_compat_mocks):
    with requests_mock.Mocker() as m:
        mock_response = watsonx_compat_mocks["watsonx_generation"]
        extended_request = "/ml/v1/deployments/"
        extended_request += "XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
        extended_request += "/text/generation?version=2021-05-01"

        m.post(
            "https://garak.example.com/" + extended_request, json=mock_response["json"]
        )

        granite_llm = WatsonXGenerator(DEFAULT_DEPLOYMENT_NAME)
        response = granite_llm._generate_with_deployment("What is this?")

        assert granite_llm.name == response["model_id"]
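
Since every HTTP call here is intercepted by requests_mock, no network traffic is generated. Assuming requests-mock is installed in the test environment, the new tests can be run in isolation with: python -m pytest tests/generators/test_watsonx.py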
29 changes: 29 additions & 0 deletions tests/generators/watsonx.json
@@ -0,0 +1,29 @@
{
    "watsonx_bearer_token": {
        "code": 200,
        "json": {
            "access_token": "fake_token1231231231",
            "refresh_token": "not_supported",
            "token_type": "Bearer",
            "expires_in": 3600,
            "expiration": 1737754747,
            "scope": "ibm openid"
        }
    },
    "watsonx_generation": {
        "code": 200,
        "json": {
            "model_id": "ibm/granite-3-8b-instruct",
            "model_version": "1.1.0",
            "created_at": "2025-01-24T20:51:59.520Z",
            "results": [
                {
                    "generated_text": "This is a test generation. :)",
                    "generated_token_count": 32,
                    "input_token_count": 6,
                    "stop_reason": "eos_token"
                }
            ]
        }
    }
}
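
These fixtures mirror the real IAM and generation response shapes; _call_model only consumes results[0].generated_text. A quick sketch of that extraction, assuming it is run from the repository root:

import json
import pathlib

# Load the mock payloads and pull out the one field the generator returns.
mocks = json.loads(pathlib.Path("tests/generators/watsonx.json").read_text())
text = mocks["watsonx_generation"]["json"]["results"][0]["generated_text"]
print(text)  # -> This is a test generation. :)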
