Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored #1

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ __pycache__/
--*.*
--*/

ignore*.py

!example.env

test_pri_*.*
Expand Down
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
}



logger = logging.getLogger(__name__)


Expand Down
4 changes: 1 addition & 3 deletions evdschat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from evdschat.core.chat import chat, chat_console

__all__ = [
chat, chat_console
]
__all__ = [chat, chat_console]
89 changes: 74 additions & 15 deletions evdschat/common/akeys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from abc import ABC
import os
from pathlib import Path
from typing import Union
from typing import Union, Dict
import time

from evdschat.common.globals import WARNING_SLEEP_SECONDS


class ErrorApiKey(Exception):
Expand All @@ -33,13 +36,21 @@ def __init__(self, message="There is an issue with the provided API key."):


class ApiKey(ABC):
def __init__(self, key: str) -> None:
def __init__(self, key: str, key_name: str = 'ApiKey') -> None:
self.key = key
self.key_name = key_name
self.check()

def __str__(self):
return self.key

def msg_before_raise(self):
showApiKeyMessage(self.__class__.__name__)
# create_env_example_file()

def check(self):
if isinstance(self.key, type(None)):
raise ErrorApiKey('Api key not set. Please see the documentation.')
raise ErrorApiKey("Api key not set. Please see the documentation.")
if not isinstance(self.key, str) or len(str(self.key)) < 5:
raise ErrorApiKey(f"Api key {self.key} is not a valid key")
return True
Expand All @@ -51,16 +62,52 @@ def set_key(self, key: str):
self.key = key


def sleep(number: int):
time.sleep(number)


def showApiKeyMessage(cls_name: str) -> None:
msg = f"""
{cls_name} not found.

create `.env` file and put necessary API keys for EVDS and {cls_name}
see documentation for details.

"""

print(msg)
sleep(WARNING_SLEEP_SECONDS)


def write_env_example(file_name: Path):
content = (
"\nOPENAI_API_KEY=sk-proj-ABCDEFGIJKLMNOPQRSTUXVZ\nEVDS_API_KEY=ABCDEFGIJKLMNOP"
)
with open(file_name, "w") as f:
f.write(content)
print("Example .env file was created.")
sleep(WARNING_SLEEP_SECONDS)


def create_env_example_file():
file_name = Path(".env")
if not file_name.exists():
write_env_example(file_name)


class OpenaiApiKey(ApiKey):
def __init__(self, key: str) -> None:
super().__init__(key)
self.key = key
self.check()

def check(self) -> Union[bool, None]:
self.key_name = 'openai_api_key'
# self.check()

def check(self, raise_=True) -> Union[bool, None]:
if not str(self.key).startswith("sk-") and len(str(self.key)) < 6:
raise ErrorApiKey(f"{self.key} is not a valid key")
self.msg_before_raise()
if raise_:
raise ErrorApiKey(f"{self.key} is not a valid key")
return False
return True


Expand All @@ -70,7 +117,7 @@ class EvdsApiKey(ApiKey): ...
class MistralApiKey(ApiKey): ...


def load_api_keys() -> Union[dict[str, str], None]:
def load_api_keys() -> Dict[str, OpenaiApiKey | EvdsApiKey]:
from dotenv import load_dotenv

env_file = Path(".env")
Expand All @@ -83,19 +130,31 @@ def load_api_keys() -> Union[dict[str, str], None]:
}


def get_openai_key():
def load_api_keys_string() -> Dict[str, str]:
from dotenv import load_dotenv

env_file = Path(".env")
load_dotenv(env_file)
openai_api_key = os.getenv("OPENAI_API_KEY")
evds_api_key = os.getenv("EVDS_API_KEY")
return {
"OPENAI_API_KEY": openai_api_key,
"EVDS_API_KEY": evds_api_key,
}


def get_openai_key() -> OpenaiApiKey:
d = load_api_keys()
return d["OPENAI_API_KEY"].key
return d["OPENAI_API_KEY"]


def get_openai_key_string() -> str | None:
d = load_api_keys_string()
return d["OPENAI_API_KEY"]

# @dataclass

class ApiKeyManager(BaseModel):
api_key: ApiKey = Field(default_factory=lambda: ApiKey())

class Config:
arbitrary_types_allowed = True

# def __get_pydantic_core_schema__(cls, handler):
# # Generate a schema if necessary, or skip it
# return handler.generate_schema(cls)
31 changes: 17 additions & 14 deletions evdschat/common/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,47 @@
from pathlib import Path
from importlib import resources
from typing import Union
from .github_actions import PytestTesting
from .github_actions import PytestTesting


class PostParams(ctypes.Structure):
_fields_ = [
("url", ctypes.c_char_p),
("prompt", ctypes.c_char_p),
("api_key", ctypes.c_char_p),
("proxy_url", ctypes.c_char_p)
("proxy_url", ctypes.c_char_p),
]

def get_exec_file(test = False ) -> Path :

def get_exec_file(test=False) -> Path:

executable_name = "libpost_request.so"
if platform.system() == "Windows":
executable_name = "libpost_request.dll"

if test or PytestTesting().is_testing():
executable_path = Path(".") / executable_name
if executable_path.is_file() :
if executable_path.is_file():
return executable_path
return False

def check_c_executable(test = False ) -> Union[Path, bool]:
executable_name= get_exec_file(test )
return False


def check_c_executable(test=False) -> Union[Path, bool]:
executable_name = get_exec_file(test)
if not executable_name:
return False
return False
try:
with resources.path("evdschat", executable_name) as executable_path:
if executable_path.is_file() and os.access(executable_path, os.X_OK):
return executable_path
except FileNotFoundError:
return False


lib_path = check_c_executable()
if lib_path:
lib = ctypes.CDLL(lib_path)

lib.post_request.argtypes = [ctypes.POINTER(PostParams)]
lib.post_request.restype = ctypes.c_char_p

Expand All @@ -54,16 +58,15 @@ def c_caller(params):

def c_caller_main(prompt, api_key, url, proxy=None):
prompt = prompt.replace("\n", " ")

params = PostParams(
url=url.encode("utf-8"),
prompt=prompt.encode("utf-8"),
api_key=api_key.encode("utf-8"),
proxy_url=proxy.encode("utf-8") if proxy else None
proxy_url=proxy.encode("utf-8") if proxy else None,
)

return c_caller(params)



else:
c_caller_main = None
8 changes: 7 additions & 1 deletion evdschat/common/github_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,23 @@
# limitations under the License.

import sys


class GithubActions:
def is_testing(self):
return "hostedtoolcache" in sys.argv[0]


class PytestTesting:
def is_testing(self):
# print(" sys.argv[0]" , sys.argv[0])
return "pytest" in sys.argv[0]


def get_input(msg, default=None):
if GithubActions().is_testing() or PytestTesting().is_testing():
if not default:
print("currently testing with no default ")
return False
return default
return input(msg)
return input(msg)
3 changes: 3 additions & 0 deletions evdschat/common/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

WARNING_SLEEP_SECONDS = 6
DEFAULT_CHAT_API_URL = "https://evdspychat-dev2-1.onrender.com/api/ask"


def global_mock():
template = """
Expand Down
1 change: 0 additions & 1 deletion evdschat/core/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from pydantic import BaseModel, Field

Expand Down
19 changes: 12 additions & 7 deletions evdschat/core/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ class GotUndefinedResult(BaseException): ...


def chat(
prompt: str,
getter: ModelAbstract = OpenAI(),
debug=False,
test=False,
force=False,
prompt: str,
getter: ModelAbstract = None,
debug=False,
test=False,
force=False,
) -> Union[Tuple[ResultChat, Notes], None]:
"""
Function to process the chat prompt and return the result.
Expand All @@ -45,6 +45,11 @@ def chat(
:return: DataFrame or Result Instance with .data (DataFrame), .metadata (DataFrame), and .to_excel (Callable).
"""

if getter is None:
getter = OpenAI()



if not force and PytestTesting().is_testing():
test = True

Expand Down Expand Up @@ -74,8 +79,8 @@ def chat(
raise GotUndefinedResult()
result, notes = res
if isinstance(result, ResultChat):
return result, notes
raise NotImplementedError("Unknown Result type ")
return result, notes
raise NotImplementedError("Unknown Result type ")


def chat_console() -> None:
Expand Down
Loading
Loading