-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
834c7e9
commit 753c9e1
Showing
6 changed files
with
1,099 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .core import RiskModel, scan_llm, scan_prompt | ||
from .core import RiskModel, Threat, scan_llm, scan_prompt | ||
|
||
__all__ = ["scan_llm", "scan_prompt", "RiskModel"] | ||
__all__ = ["scan_llm", "scan_prompt", "RiskModel", "Threat"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .llm import LastLayerSecurity | ||
|
||
__all__ = ["LastLayerSecurity"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import logging | ||
from typing import Any, Callable, List | ||
|
||
import last_layer | ||
|
||
from langchain_core.language_models.llms import BaseLLM | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def default_handler(text, risk: last_layer.RiskModel): | ||
if risk.passed: | ||
return | ||
logger.warning(f"Security risk: {risk} detected in text: {text}") | ||
|
||
|
||
class LastLayerSecurity(BaseLLM): | ||
llm: BaseLLM | ||
handle_prompt_risk: Callable[str, last_layer.RiskModel] = default_handler | ||
handle_response_risk: Callable[str, last_layer.RiskModel] = default_handler | ||
ignore_opts: list[last_layer.Threat] = [] | ||
|
||
@property | ||
def _llm_type(self) -> str: | ||
return "LastLayerSecurity" | ||
|
||
def _generate( | ||
self, | ||
prompts: List[str], | ||
**kwargs: Any, | ||
) -> Any: | ||
"""Run the LLM on the given prompts.""" | ||
|
||
for prompt in prompts: | ||
risk = last_layer.scan_prompt(prompt) | ||
self.handle_prompt_risk(prompt, risk) | ||
result = self.llm._generate(prompts, **kwargs) | ||
|
||
for top_gen in result.generations: | ||
for gen in top_gen: | ||
risk = last_layer.scan_llm(gen.text) | ||
self.handle_response_risk(gen.text, risk) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import unittest | ||
from .llm import LastLayerSecurity | ||
|
||
|
||
class TestScanPrompt(unittest.TestCase): | ||
@unittest.skip("Not implemented") | ||
def test_integration(self): | ||
# The line `from langchain_openai import OpenAI` is importing the `OpenAI` class from the | ||
# `langchain_openai` module. This allows the code to use the `OpenAI` class and its | ||
# functionalities within the current module or script. | ||
from langchain_contrib.llms.testing import FakeLLM | ||
|
||
secure_llm = LastLayerSecurity( | ||
llm=FakeLLM(verbose=True, sequenced_responses=["One", "Two", "Three"]) | ||
) | ||
response = secure_llm.invoke( | ||
"Summarize this message: my name is Bob Dylan. My SSN is 123-45-6789." | ||
) | ||
print(f"{response=}") |
Oops, something went wrong.