diff --git a/askchat/__init__.py b/askchat/__init__.py index a60dee9..580ab93 100644 --- a/askchat/__init__.py +++ b/askchat/__init__.py @@ -2,11 +2,13 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '1.1.2' +__version__ = '1.1.3' import asyncio from pathlib import Path import click +from dotenv import set_key +import os # Main environment file CONFIG_PATH = Path.home() / ".askchat" @@ -14,6 +16,24 @@ MAIN_ENV_PATH = Path.home() / '.askchat' / '.env' ENV_PATH = Path.home() / '.askchat' / 'envs' +raw_env_text = f""""# Description: Env file for askchat. +# Current version: {__version__} + +# The base url of the API (with suffix /v1) +# This will override OPENAI_API_BASE_URL if both are set. +OPENAI_API_BASE='' + +# The base url of the API (without suffix /v1) +OPENAI_API_BASE_URL='' + +# Your API key +OPENAI_API_KEY='' + +# The default model name +# You can use `askchat --all-valid-models` to see supported models +OPENAI_API_MODEL='' +""" + # Autocompletion # environment name completion class EnvNameCompletionType(click.ParamType): @@ -42,20 +62,36 @@ async def show_resp(chat, **options): print() # add a newline if the message doesn't end with one return msg -def write_config(config_file, api_key, model, base_url, api_base): - """Write the environment variables to a config file.""" - def write_var(f, var, value, desc): - value = value if value else "" - f.write(f"\n\n# {desc}\n") - f.write(f'{var}="{value}"') +def set_keys(config_file, keys): + """Set multiple keys in the config file.""" + for key, value in keys.items(): + if value: + set_key(config_file, key, value) + +def raw_config(config_file:str): + """Empty config file.""" + if not CONFIG_PATH.exists(): + CONFIG_PATH.mkdir(parents=True) with open(config_file, "w") as f: - f.write("#Description: Env file for askchat.\n" +\ - "#Current version: " + __version__) - # write the environment table - write_var(f, "OPENAI_API_BASE", api_base, "The base url of the API (with suffix /v1)" +\ - "\n# This will override OPENAI_API_BASE_URL if both are set.") - write_var(f, "OPENAI_API_BASE_URL", base_url, "The base url of the API (without suffix /v1)") - - write_var(f, "OPENAI_API_KEY", api_key, "Your API key") - write_var(f, "OPENAI_API_MODEL", model, "The model name\n" +\ - "# You can use `askchat --all-valid-models` to see supported models") \ No newline at end of file + f.write(raw_env_text) + +def init_config(config_file:str): + """Initialize the config file with the current environment variables.""" + raw_config(config_file) + set_keys(config_file, { + "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"), + "OPENAI_API_MODEL": os.getenv("OPENAI_API_MODEL"), + "OPENAI_API_BASE_URL": os.getenv("OPENAI_API_BASE_URL"), + "OPENAI_API_BASE": os.getenv("OPENAI_API_BASE"), + }) + +def write_config(config_file, api_key, model, base_url, api_base, overwrite=False): + """Write the environment variables to a config file.""" + if overwrite or not config_file.exists(): + raw_config(config_file) + set_keys(config_file, { + "OPENAI_API_KEY": api_key, + "OPENAI_API_MODEL": model, + "OPENAI_API_BASE_URL": base_url, + "OPENAI_API_BASE": api_base, + }) \ No newline at end of file diff --git a/askchat/askenv.py b/askchat/askenv.py index 75e2787..9afd8d3 100644 --- a/askchat/askenv.py +++ b/askchat/askenv.py @@ -41,7 +41,7 @@ def new(name, api_key, base_url, api_base, model): click.confirm("Do you want to continue?", abort=True) else: click.echo(f"Environment '{name}' created.") - write_config(config_path, api_key, model, base_url, api_base) + write_config(config_path, api_key, model, base_url, api_base, overwrite=True) @cli.command() @click.argument('name', required=False, type=EnvNameCompletionType()) @@ -120,16 +120,10 @@ def config(name, api_key, base_url, api_base, model): return config_path = ENV_PATH / f'{name}.env' if name else MAIN_ENV_PATH if not config_path.exists(): - click.echo(f"Environment '{config_path}' not found.") + click.echo(f"Environment '{config_path}' not found." +\ + "Use `askenv new` to create a new environment." ) return - if api_key: - set_key(config_path, "OPENAI_API_KEY", api_key) - if base_url: - set_key(config_path, "OPENAI_API_BASE_URL", base_url) - if api_base: - set_key(config_path, "OPENAI_API_BASE", api_base) - if model: - set_key(config_path, "OPENAI_API_MODEL", model) + write_config(config_path, api_key, model, base_url, api_base) click.echo(f"Environment {config_path} updated.") if __name__ == '__main__': diff --git a/askchat/cli.py b/askchat/cli.py index c55a62b..c6b343a 100644 --- a/askchat/cli.py +++ b/askchat/cli.py @@ -8,7 +8,7 @@ from chattool import Chat, debug_log from pathlib import Path from askchat import ( - show_resp, write_config + show_resp, write_config, init_config , ENV_PATH, MAIN_ENV_PATH , CONFIG_PATH, CONFIG_FILE , EnvNameCompletionType, ChatFileCompletionType @@ -37,12 +37,10 @@ def setup(): def generate_config_callback(ctx, param, value): """Generate a configuration file by environment table.""" if not value: return - api_key, model = os.getenv("OPENAI_API_KEY"), os.getenv("OPENAI_API_MODEL") - base_url, api_base = os.getenv("OPENAI_API_BASE_URL"), os.getenv("OPENAI_API_BASE") # save the config file if os.path.exists(CONFIG_FILE): click.confirm(f"Overwrite the existing configuration file {CONFIG_FILE}?", abort=True) - write_config(CONFIG_FILE, api_key, model, base_url, api_base) + init_config(CONFIG_FILE) print("Created config file at", CONFIG_FILE) ctx.exit() diff --git a/setup.py b/setup.py index 8f67ed6..260b13c 100644 --- a/setup.py +++ b/setup.py @@ -4,12 +4,12 @@ from setuptools import setup, find_packages -VERSION = '1.1.2' +VERSION = '1.1.3' with open('README.md') as readme_file: readme = readme_file.read() -requirements = ['chattool>=3.1.3', "python-dotenv>=0.17.0", 'Click>=8.0'] +requirements = ['chattool>=3.1.4', "python-dotenv>=0.17.0", 'Click>=8.0'] test_requirements = ['pytest>=3'] diff --git a/tests/test_askenv.py b/tests/test_askenv.py index 30d75b1..30b1d95 100644 --- a/tests/test_askenv.py +++ b/tests/test_askenv.py @@ -33,7 +33,7 @@ def test_overwrite_environment_confirm(runner, setup_env): assert "Do you want to continue?" in result.output # Verify the environment was overwritten by checking if the new API key is in the file with open(config_path) as f: - assert 'OPENAI_API_KEY="456"' in f.read() + assert "OPENAI_API_KEY='456'" in f.read() def test_list_initially_empty(runner, setup_env): """Ensure no environments are listed when none have been created."""