diff --git a/ner/ner_solution.py b/ner/ner_solution.py index 59f6981..a2c6374 100644 --- a/ner/ner_solution.py +++ b/ner/ner_solution.py @@ -3,6 +3,7 @@ import os import pandas as pd import requests +import boto3 import streamlit as st import tempfile @@ -39,50 +40,42 @@ def convert_to_json_schema(yaml_str): return ret_str -def transcribe_audio(file_path: str, octoai_token: str): +def transcribe_audio(encoded_audio: str, octoai_token: str): """ Takes the file path of an audio file and transcribes it to text. Returns a string with the transcribed text. """ - with open(file_path, "rb") as f: - encoded_audio = str(base64.b64encode(f.read()), "utf-8") - reply = requests.post( - "https://whisper2-or1pkb9b656p.octoai.run/predict", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {octoai_token}", - }, - json={"audio": encoded_audio}, - timeout=300, - ) - try: - transcript = reply.json()["transcription"] - except Exception as e: - print(e) - print(reply.text) - raise ValueError("The transcription could not be completed.") + reply = requests.post( + "https://whisper2-or1pkb9b656p.octoai.run/predict", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {octoai_token}", + }, + json={"audio": encoded_audio}, + timeout=300, + ) + try: + transcript = reply.json()["transcription"] + except Exception as e: + print(e) + print(reply.text) + raise ValueError("The transcription could not be completed.") return transcript -def file_to_base64(file_path): - with open(file_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode("utf-8") - - -def process_image(file_path: str, octoai_token: str): - # Convert the images to base64 strings - base64_str = f"data:image/png;base64,{file_to_base64(file_path)}" +def process_image(encoded_image: str, octoai_token: str, yaml: str): + print(yaml) messages = [ { "role": "user", "content": [ { "type": "text", - "text": "Describe what you see in the image in great detail", + "text": "Describe what you see in the image in great detail. Be as exhaustive and factual as possible. Provide detail according to the JSON description below:\n{}".format(yaml), }, - {"type": "image_url", "image_url": {"url": base64_str}}, + {"type": "image_url", "image_url": {"url": encoded_image}}, ], } ] @@ -139,7 +132,7 @@ def submit_new_token(): st.session_state.octoai_api_key = st.session_state.token_text_input -st.set_page_config(layout="wide", page_title="NER Playground") +st.set_page_config(layout="wide", page_title="Multi-Modal Data Extractor") if "octoai_api_key" not in st.session_state: st.session_state.octoai_api_key = os.environ.get("OCTOAI_API_KEY", None) @@ -160,24 +153,41 @@ def submit_new_token(): """ ) else: - with st.form("input-form", clear_on_submit=True, border=True): - tab1, tab2 = st.tabs(["Files", "URLs"]) + with st.form("input-form", clear_on_submit=False, border=True): + tab1, tab2, tab3 = st.tabs(["Local Files", "URLs", "S3"]) + # Local files with tab1: upload_files = st.file_uploader( - "Upload your files here", + "Upload your PDFs/audio/JPEG files here", type=[".pdf", ".mp3", ".mp4", ".wav", ".jpg", ".jpeg"], accept_multiple_files=True, key="upload_files", ) st.caption("Click on submit after uploading to process the files.") + # URLs with tab2: website_url = st.text_input( - "Enter the URL of the website to scrape", key="website_url" + "Enter the URL(s) of the website to scrape", key="website_url" ) st.caption("Use comma for multiple URLs.") + # S3 + with tab3: + aws_access_key_id = st.text_input( + "AWS Access Key ID", value="AWSACCESSKEYID" + ) + aws_secret_access_key = st.text_input( + "AWS Secret Key", type="password", value="asdf" + ) + aws_s3_bucket = st.text_input( + "AWS S3 bucket", value="bucket-name" + ) + aws_s3_bucket_path = st.text_input( + "Path to directory to process", value="path/to/dir/" + ) + st.form_submit_button("Submit", on_click=submit_onclick) st.write( @@ -187,8 +197,8 @@ def submit_new_token(): "[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/octoml/octoai-solutions)" ) -st.write("## NER Playground") -st.caption("Named Entity Recognition Playground.") +st.write("## Multi-Modal Data Extractor") +st.caption("Powered by OctoAI.") ################################################# # Section 1: Inputs @@ -207,12 +217,14 @@ def submit_new_token(): executive_summary: desc: executive summary of the document """ +st.session_state["yaml_format"] = yaml_format def update_json_schema(code): # Prepare the JSON schema json_schema = convert_to_json_schema(code) st.session_state["json_schema"] = json_schema + st.session_state["yaml_format"] = code if "json_schema" not in st.session_state: @@ -239,7 +251,6 @@ def update_json_schema(code): ] code_response = code_editor(code=yaml_format, lang="yaml", buttons=custom_btns) if code_response["text"]: - print(code_response["text"]) update_json_schema(code_response["text"]) if not st.session_state.get("process_new_inputs", False) and ( @@ -291,15 +302,26 @@ def update_json_schema(code): or upload_file.name.endswith(".mp4") or upload_file.name.endswith(".wav") ): - doc_str = transcribe_audio( - tf.name, st.session_state.octoai_api_key - ) - elif upload_file.name.endswith("jpg") or upload_file.name.endswith( - "jpeg" + # Convert the image to base64 string + with open(tf.name, "rb") as f: + encoded_audio = str(base64.b64encode(f.read()), "utf-8") + doc_str = transcribe_audio( + encoded_audio, st.session_state.octoai_api_key + ) + # Image file handling + elif ( + upload_file.name.endswith("jpg") + or upload_file.name.endswith("jpeg") ): - doc_str = process_image( - tf.name, st.session_state.octoai_api_key - ) + # Convert the images to base64 string + with open(tf.name, "rb") as f: + encoded_image = base64.b64encode(image_file.read()).decode("utf-8") + encoded_image = f"data:image/png;base64,{encoded_image}" + doc_str = process_image( + encoded_image, + st.session_state.octoai_api_key, + str(yaml.load(st.session_state["yaml_format"], Loader=yaml.SafeLoader)) + ) st.session_state.doc_str.append(doc_str) elif website_url: @@ -345,6 +367,66 @@ def update_json_schema(code): f"An error occurred while processing {got_error}. Please refresh and try again." ) + elif aws_access_key_id and aws_secret_access_key: + + # Create an S3 client + s3_client = boto3.client( + 's3', + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key + ) + # Get the list in the bucket directory + result = s3_client.list_objects( + Bucket=aws_s3_bucket, + Prefix=aws_s3_bucket_path, + Delimiter='/' + ) + + if len(result.get('Contents')) == 1: + spinner_message = f"Processing {result.get('Contents')[0].get('Key')} into Markdown..." + else: + spinner_message = f"Processing {len(result.get('Contents'))-1} files into Markdown..." + # Preprocess documents + with st.status(spinner_message): + for bucket_file in result.get('Contents'): + f_name = bucket_file.get('Key') + data = s3_client.get_object(Bucket=aws_s3_bucket, Key=f_name) + if f_name == bucket_file: + continue + # PDF handling + if f_name.endswith(".pdf"): + # Read in first document + documents = parser.load_data( + data['Body'].read(), + extra_info={"file_name": f_name} + ) + doc_str = "" + for document in documents: + doc_str += document.text + doc_str += "\n" + st.session_state.doc_str.append(doc_str) + # Audio file handling + elif ( + f_name.endswith(".mp3") + or f_name.endswith(".mp4") + or f_name.endswith(".wav") + ): + encoded_audio = str(base64.b64encode(data['Body'].read()), "utf-8") + doc_str = transcribe_audio( + encoded_audio, st.session_state.octoai_api_key + ) + st.session_state.doc_str.append(doc_str) + elif f_name.endswith("jpg") or f_name.endswith("jpeg"): + # Convert the images to base64 string + encoded_image = base64.b64encode(data['Body'].read()).decode("utf-8") + encoded_image = f"data:image/png;base64,{encoded_image}" + doc_str = process_image( + encoded_image, + st.session_state.octoai_api_key, + str(yaml.load(st.session_state["yaml_format"], Loader=yaml.SafeLoader)) + ) + st.session_state.doc_str.append(doc_str) + ################################################# # Section 3: Processing the outputs @@ -378,7 +460,7 @@ def update_json_schema(code): """ data = { - "model": "meta-llama-3.1-70b-instruct", + "model": "meta-llama-3.1-405b-instruct", "messages": [ { "role": "system", diff --git a/ner/requirements.txt b/ner/requirements.txt index 8b4d1ba..e61a20a 100644 --- a/ner/requirements.txt +++ b/ner/requirements.txt @@ -5,3 +5,4 @@ llama-parse streamlit-code-editor firecrawl-py snowflake-connector-python[pandas] +boto3 diff --git a/ner/yaml_examples/electronic_health_records.yaml b/ner/yaml_examples/electronic_health_records.yaml new file mode 100644 index 0000000..aa92774 --- /dev/null +++ b/ner/yaml_examples/electronic_health_records.yaml @@ -0,0 +1,9 @@ +# Describe the fields of information in YAML format +executive_summary: + desc: one sentence executive summary of the transcript +symptoms_list: + desc: comma separated list of symptoms mentioned in the transcript +medication_list: + desc: comma separated list of medication mentioned in the transcript +procedures_list: + desc: comma separated list of procedures mentioned in the transcript diff --git a/ner/yaml_examples/expense_management.yaml b/ner/yaml_examples/expense_management.yaml new file mode 100644 index 0000000..c02ff69 --- /dev/null +++ b/ner/yaml_examples/expense_management.yaml @@ -0,0 +1,13 @@ +# Describe the fields of information in YAML format +merchant_name: + desc: name of merchant +date: + desc: date on the receipt +type: + desc: type of expense (e.g. food, transportation, lodging...) +list_of_items: + desc: comma separated list of items on the bill +total: + desc: bill total in dollars +includes_alcohol: + desc: yes if the bill includes alcohol, else no \ No newline at end of file diff --git a/ner/yaml_examples/financial_records.yaml b/ner/yaml_examples/financial_records.yaml new file mode 100644 index 0000000..c53ce07 --- /dev/null +++ b/ner/yaml_examples/financial_records.yaml @@ -0,0 +1,11 @@ +# Describe the fields of information in YAML format +product_revenue: + desc: product revenue for the quarter +product_revenue_year_over_year_growth: + desc: product revenue year over year growth for the quarter +customers_over_one_mil: + desc: number of customer with trailing 12-month product revenue greater than 1 million +investor_contact: + desc: name of investor contact +investor_contact_email: + desc: email address of investor contact \ No newline at end of file diff --git a/ner/yaml_examples/vegetation_management.yaml b/ner/yaml_examples/vegetation_management.yaml new file mode 100644 index 0000000..699fb8c --- /dev/null +++ b/ner/yaml_examples/vegetation_management.yaml @@ -0,0 +1,7 @@ +# Describe the fields of information in YAML format +power_delivery_line: + desc: yes if power delivery line is present in image, else no +foliage: + desc: yes if foliage is present in image, else no +foliage_close_to_power_delivery_line: + desc: yes if foliage is in contact with, or too close to power delivery line in image, else no \ No newline at end of file