Skip to content

Commit

Permalink
deploy: Reformat validate-db using Black.
Browse files Browse the repository at this point in the history
Issue: #96

Signed-off-by: Nikolay Martyanov <[email protected]>
  • Loading branch information
OhmSpectator committed Nov 29, 2023
1 parent d451245 commit a2e5d96
Showing 1 changed file with 64 additions and 25 deletions.
89 changes: 64 additions & 25 deletions deployment/validate-db/validate-db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
def save_error_report(region_id, gadm_uid, hierarchy, feedback):
try:
with open(error_report_file, "a") as file:
file.write(f"Region ID: {region_id}, GADM ID: {gadm_uid}\n{hierarchy}\n{feedback}\n\n")
file.write(
f"Region ID: {region_id}, GADM ID: {gadm_uid}\n{hierarchy}\n{feedback}\n\n"
)
except IOError as e:
print(f"Error during file ({error_report_file}) operation: {e}")

Expand All @@ -36,15 +38,19 @@ def add_to_cache(regions_id):
def red_text(text):
return f"\033[38;5;9m{text}\033[0m"


def orange_text(text):
return f"\033[38;5;208m{text}\033[0m"


def slight_yellow_text(text):
return f"\033[38;5;220m{text}\033[0m"


def green_text(text):
return f"\033[92m{text}\033[0m"


def print_error_title(title, severity):
if severity == "high":
print(red_text(title))
Expand Down Expand Up @@ -72,11 +78,15 @@ def get_hierarchy(cur, region_id):

# Building the path from the given region up to the root
while region_id:
cur.execute("SELECT name, parent_region_id FROM regions WHERE id = %s", (region_id,))
cur.execute(
"SELECT name, parent_region_id FROM regions WHERE id = %s", (region_id,)
)
row = cur.fetchone()
if row:
name, parent_region_id = row
path_to_root.insert(0, name) # Insert at the beginning to build the path bottom-up.
path_to_root.insert(
0, name
) # Insert at the beginning to build the path bottom-up.
if original_region_parent_id is None:
original_region_parent_id = parent_region_id
region_id = parent_region_id
Expand All @@ -87,8 +97,10 @@ def get_hierarchy(cur, region_id):
siblings = []
siblings_ids = []
if original_region_parent_id is not None:
cur.execute("SELECT name, id FROM regions WHERE parent_region_id = %s AND id != %s",
(original_region_parent_id, original_region_id))
cur.execute(
"SELECT name, id FROM regions WHERE parent_region_id = %s AND id != %s",
(original_region_parent_id, original_region_id),
)
rows = cur.fetchall()
siblings = [row[0] for row in rows]
siblings_ids = [row[1] for row in rows]
Expand All @@ -114,18 +126,33 @@ def get_hierarchy(cur, region_id):
db_name = os.getenv("DB_NAME")
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")
db_host = os.getenv("DB_HOST", 'localhost')
db_host = os.getenv("DB_HOST", "localhost")
openai.api_key = os.getenv("OPENAI_API_KEY")

# Check that the DB credentials were provided
if not all([db_name, db_user, db_password, openai.api_key]):
print("Error: DB_NAME, DB_USER, DB_PASSWORD, and OPENAI_API_KEY must be provided in .env")
print(
"Error: DB_NAME, DB_USER, DB_PASSWORD, and OPENAI_API_KEY must be provided in .env"
)
sys.exit(1)

# Setup argument parser
parser = argparse.ArgumentParser(description="Script to validate region hierarchies with OpenAI API.")
parser.add_argument('-c', '--cheap', action='store_true', help='Use the gpt-3.5-turbo model instead of gpt-4.')
parser.add_argument('-n', '--num-regions', type=int, default=10, help='Number of random regions to check.')
parser = argparse.ArgumentParser(
description="Script to validate region hierarchies with OpenAI API."
)
parser.add_argument(
"-c",
"--cheap",
action="store_true",
help="Use the gpt-3.5-turbo model instead of gpt-4.",
)
parser.add_argument(
"-n",
"--num-regions",
type=int,
default=10,
help="Number of random regions to check.",
)

# Parse arguments
args = parser.parse_args()
Expand All @@ -137,38 +164,45 @@ def get_hierarchy(cur, region_id):
num_regions_to_check = args.num_regions

# Connect to your database
conn = psycopg2.connect(dbname=db_name, user=db_user, password=db_password, host=db_host)
conn = psycopg2.connect(
dbname=db_name, user=db_user, password=db_password, host=db_host
)
cur = conn.cursor()


checked_regions = load_cache()


# Generate a WHERE clause to exclude the regions that were already checked
where_clause = "WHERE id NOT IN (" + ",".join(checked_regions) + ")" if checked_regions else ""
where_clause = (
"WHERE id NOT IN (" + ",".join(checked_regions) + ")" if checked_regions else ""
)
# Generate a list of random region IDs to check, excluding the ones that were already checked, by SQL
cur.execute(f"""
cur.execute(
f"""
SELECT id FROM regions
{where_clause}
ORDER BY random()
LIMIT {num_regions_to_check}
""")
"""
)
region_ids = [row[0] for row in cur.fetchall()]

error_mark = "WARNING"

initial_prompt = (
"Check out our region hierarchy. It is very important to me. Reply in JSON."
"If it's perfect, reply with `{\"status\": \"valid\"}'."
"Notice an issue? Point it out with `{\"status\": \"error\", \"severity\": \"low|medium|high\", \"detail\": \"Explain the issue\"}`."
"The hierarchy is provided in the following format: Parent -> Parent -> Parent -> Sibling, Sibling, Sibling."
"Check out our region hierarchy. It is very important to me. Reply in JSON."
'If it\'s perfect, reply with `{"status": "valid"}\'.'
'Notice an issue? Point it out with `{"status": "error", "severity": "low|medium|high", "detail": "Explain the issue"}`.'
"The hierarchy is provided in the following format: Parent -> Parent -> Parent -> Sibling, Sibling, Sibling."
)

client = openai.Client(api_key=openai.api_key)

input_tokens = 0
output_tokens = 0


# Validate region hierarchy data for selected regions
def valid_schema(json_feedback):
if "status" not in json_feedback:
Expand All @@ -191,22 +225,25 @@ def valid_schema(json_feedback):
for region_id in region_ids:
cur.execute("SELECT gadm_uid FROM regions WHERE id = %s", (region_id,))
result = cur.fetchone() # Store the result of fetchone
gadm_uid = result[0] if result else None # Check if result is not None before subscripting
gadm_uid = (
result[0] if result else None
) # Check if result is not None before subscripting
hierarchy, siblings = get_hierarchy(cur, region_id)
title_message = f"Validating region hierarchy: {hierarchy}"
print(f"{'-' * 80}")
print(f"Validating region hierarchy: {hierarchy}") # Tab at the beginning for separation
print(
f"Validating region hierarchy: {hierarchy}"
) # Tab at the beginning for separation
try:
completion = client.chat.completions.create(
model=model_to_use,
messages=[
{"role": "system", "content": initial_prompt},
{"role": "user", "content": hierarchy}
{"role": "user", "content": hierarchy},
],
response_format = {"type": "json_object"},
response_format={"type": "json_object"},
n=1,

max_tokens=150
max_tokens=150,
)
feedback = completion.choices[0].message.content
input_tokens += completion.usage.prompt_tokens
Expand All @@ -224,7 +261,9 @@ def valid_schema(json_feedback):

if json_feedback["status"] == "error":
# Red text for the error message part only
print_error_title(f"Potential error in region id: {region_id}", json_feedback["severity"])
print_error_title(
f"Potential error in region id: {region_id}", json_feedback["severity"]
)
# Normal color for feedback
print(json_feedback["detail"])
save_error_report(region_id, gadm_uid, hierarchy, json_feedback["detail"])
Expand Down

0 comments on commit a2e5d96

Please sign in to comment.