Update llm_integration.py
Cybonto committed Dec 13, 2024
1 parent 8fc0f86 commit 0ca5f88
Showing 1 changed file with 103 additions and 36 deletions.
139 changes: 103 additions & 36 deletions streamlit_app/app/entity_bridge/llm_integration.py
@@ -17,8 +17,12 @@
# Note: Users must install the required SDKs for each provider they intend to use.

# Ollama
try:
    import ollama
    from ollama import Client as OllamaClient
    import requests
except ImportError:
    OllamaClient = None

# OpenAI
try:
@@ -35,21 +39,26 @@
# Google Vertex AI
try:
    from google.cloud import aiplatform
    from google.oauth2 import service_account
    # TextGenerationModel is used by generate_entity_mappings_with_llm below.
    from vertexai.preview.language_models import ChatModel, TextGenerationModel
except ImportError:
    aiplatform = None
    service_account = None
    ChatModel = None
    TextGenerationModel = None

# AWS Bedrock (Note: AWS SDK supports Bedrock in later versions)
try:
    import boto3
except ImportError:
    boto3 = None

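# Installation hint (a suggestion, not part of the committed module): only the
# SDKs for the providers actually in use are needed, e.g.
#   pip install ollama openai anthropic google-cloud-aiplatform boto3 regex
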
def setup_llm_client(provider: str, selected_model: str = None, **credentials) -> Any:
    """
    Set up the LLM client based on the selected provider and credentials.

    Args:
        provider (str): The name of the LLM provider ('ollama', 'openai', 'anthropic', 'vertexai', 'bedrock').
        selected_model (str): The name of the LLM model to use.
        **credentials: Keyword arguments containing necessary credentials.

    Returns:
@@ -64,7 +73,7 @@ def setup_llm_client(provider: str, **credentials) -> Any:
        api_key = credentials.get('api_key')
        if not api_key:
            raise ValueError("API key is required for OpenAI.")
        openai_client = OpenAI(api_key=api_key)
        return openai_client
    elif provider.lower() == 'anthropic':
        if anthropic is None:
@@ -78,23 +87,46 @@ def setup_llm_client(provider: str, **credentials) -> Any:
        if aiplatform is None:
            raise ImportError("Google Cloud AI Platform library is not installed. Please install it with 'pip install google-cloud-aiplatform'.")
        # For Vertex AI, authentication is typically handled via environment variables or service accounts.
        project_id = credentials.get('project_id')
        location = credentials.get('location', 'us-central1')
        credentials_info = credentials.get('credentials_info')
        if not project_id:
            raise ValueError("Project ID is required for Google Vertex AI.")
        if credentials_info:
            credentials_obj = service_account.Credentials.from_service_account_info(credentials_info)
        else:
            credentials_obj = None  # fall back to application default credentials
        aiplatform.init(project=project_id, location=location, credentials=credentials_obj)
        return aiplatform
    elif provider.lower() == 'bedrock':
        if boto3 is None:
            raise ImportError("Boto3 library is not installed. Please install it with 'pip install boto3'.")
        # AWS credentials are typically handled via environment variables or configuration files.
        aws_access_key_id = credentials.get('aws_access_key_id')
        aws_secret_access_key = credentials.get('aws_secret_access_key')
        aws_session_token = credentials.get('aws_session_token')
        region_name = credentials.get('region_name', 'us-east-1')
        if not aws_access_key_id or not aws_secret_access_key:
            raise ValueError("AWS credentials are required for Amazon Bedrock.")
        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=region_name,
        )
        # Create the client from the session so the credentials above are used;
        # 'bedrock-runtime' is for inference, 'bedrock' for control-plane calls.
        client = session.client('bedrock-runtime')
        return client
    elif provider.lower() == 'ollama':
        if OllamaClient is None:
            raise ImportError("Ollama client library is not installed. Please install it with 'pip install -U ollama'.")
        base_url = credentials.get('base_url', 'http://localhost:11434')
        client = OllamaClient(host=base_url)
        return client
    else:
        raise ValueError(f"Unsupported provider: {provider}")

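# A minimal usage sketch (illustrative only; the credential values and
# environment variable names below are assumptions, not part of this module):
#
#   import os
#   openai_client = setup_llm_client('openai', api_key=os.environ['OPENAI_API_KEY'])
#   bedrock_client = setup_llm_client(
#       'bedrock',
#       aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
#       aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
#       region_name='us-east-1',
#   )
#   ollama_client = setup_llm_client('ollama', base_url='http://localhost:11434')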

def generate_entity_mappings_with_llm(prompt: str, client: Any, provider: str, model_name: str, **kwargs) -> Dict[str, Any]:
    """
    Generate entity mappings using the provided LLM client.
@@ -111,58 +143,85 @@ def generate_entity_mappings_with_llm(prompt: str, client: Any, provider: str, m
        Exception: If the LLM generation fails.
    """
    try:
        generated_text = ""
        if provider.lower() == 'openai':
            # For OpenAI, use the chat completions API of the v1 client
            response = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "system", "content": prompt}],
                max_tokens=500,
                #n=1,
                #stop=None,
                temperature=0.1,
            )
            generated_text = response.choices[0].message.content.strip()
        elif provider.lower() == 'anthropic':
            # For Anthropic's Claude API
            response = client.completion(
                prompt=prompt,
                model=model_name,
                max_tokens_to_sample=500,
                stop_sequences=[],
                temperature=0.1,
            )
            generated_text = response.get('completion', '').strip()
        elif provider.lower() == 'vertexai':
            # For Google Vertex AI text-generation models
            model = TextGenerationModel.from_pretrained(model_name)
            response = model.predict(prompt)
            generated_text = response.text.strip()
        elif provider.lower() == 'bedrock':
            # For AWS Bedrock: the request body format depends on the model family
            if "anthropic" in model_name:
                # Anthropic text-completion models on Bedrock use snake_case keys
                body = json.dumps({
                    "prompt": "\n\nHuman: " + prompt + "\n\nAssistant:",
                    "max_tokens_to_sample": 500,
                    "temperature": 0.1,
                    "stop_sequences": ["\n\nHuman:"]
                })
            elif "ai21" in model_name:
                body = json.dumps({
                    "prompt": prompt,
                    "maxTokens": 500,
                    "temperature": 0.1,
                    "topP": 1,
                    "stopSequences": ["<|END"]
                })
            else:
                raise ValueError(f"Unsupported model for Amazon Bedrock: {model_name}")
            response = client.invoke_model(
                modelId=model_name,
                accept='application/json',
                contentType='application/json',
                body=body
            )
            response_body = response['body'].read().decode('utf-8')
            response_json = json.loads(response_body)
            if "completion" in response_json:
                # Anthropic models return {"completion": "..."}
                generated_text = response_json['completion']
            elif "completions" in response_json:
                # AI21 models return {"completions": [{"data": {"text": "..."}}]}
                generated_text = response_json['completions'][0]['data']['text']
            elif "result" in response_json:
                generated_text = response_json['result']
            else:
                raise Exception("Unexpected response format from Amazon Bedrock")

        elif provider.lower() == 'ollama':
            # For the Ollama client
            response = client.chat(
                model=model_name,
                messages=[{'role': 'user', 'content': prompt}]
            )
            generated_text = response["message"]["content"].strip()
        else:
            raise ValueError(f"Unsupported provider: {provider}")

        if not generated_text:
            raise ValueError("LLM returned an empty response.")

        # Parse the generated_text into a dictionary
        entity_mappings = parse_llm_output(generated_text)
        return entity_mappings

    except Exception as e:
        logging.error(f"Error generating entity mappings with {provider}: {e}")
        raise Exception(f"LLM generation failed: {e}")
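
# Example call (a sketch: assumes `client` came from setup_llm_client above,
# the prompt asks the model to reply with a JSON object, and the model name
# is illustrative):
#
#   mappings = generate_entity_mappings_with_llm(
#       "Map these entity names to canonical forms and reply in JSON: ...",
#       client, provider='openai', model_name='gpt-4o'
#   )
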
@@ -184,8 +243,6 @@ def integrate_llm_in_entity_matching(similarity_df, client, provider: str, model
    Side Effects:
        May involve additional API calls to the LLM provider.
    """

    st.write("Enhancing entity matching using LLM...")

@@ -240,11 +297,21 @@ def parse_llm_output(output_text: str) -> Dict[str, Any]:
    import regex  # third-party 'regex' module; the stdlib re does not support recursive patterns like (?R)
    try:
        # Attempt to extract balanced JSON objects from the output_text
        json_pattern = r'\{(?:[^{}]|(?R))*\}'
        matches = regex.findall(json_pattern, output_text, flags=regex.DOTALL)
        if matches:
            # Attempt to parse each match until one succeeds
            for json_str in matches:
                try:
                    parsed_output = json.loads(json_str)
                    return parsed_output
                except json.JSONDecodeError:
                    continue
            raise ValueError("Failed to parse JSON content from LLM output")
        else:
            raise ValueError("No JSON content found in LLM output")
    except Exception as e:
        raise ValueError(f"Error parsing LLM output: {e}")

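# For example, parse_llm_output should recover the JSON object embedded in a
# chatty response (hypothetical input):
#
#   text = 'Here is the mapping: {"Acme Corp": "ACME Corporation"}'
#   parse_llm_output(text)  # -> {'Acme Corp': 'ACME Corporation'}
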
def infer_true_parents(parent_names_list, parent_entity_type, client, provider, model_name):
    """