Skip to content

Commit

Permalink
Wd/embed (Chainlit#679)
Browse files Browse the repository at this point in the history
* relax fast api

* allow for custom fonts

* make cors configurable

* rename to copilot and serve copilot index.js


---------

Co-authored-by: SuperTurk <[email protected]>
  • Loading branch information
willydouhard and alimtunc authored Jan 19, 2024
1 parent f732544 commit 9c914c9
Show file tree
Hide file tree
Showing 102 changed files with 11,741 additions and 577 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ venv
.DS_Store

.chainlit
!cypress/e2e/**/*/.chainlit
!cypress/e2e/**/*/.chainlit/*
chainlit.md

cypress/screenshots
Expand Down
11 changes: 11 additions & 0 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

from fastapi import Request, Response
from pydantic.dataclasses import dataclass
from starlette.datastructures import Headers

if TYPE_CHECKING:
Expand Down Expand Up @@ -294,6 +295,15 @@ def sleep(duration: int):
return asyncio.sleep(duration)


@dataclass()
class CopilotFunction:
name: str
args: Dict[str, Any]

def acall(self):
return context.emitter.send_call_fn(self.name, self.args)


__getattr__ = make_module_getattr(
{
"LangchainCallbackHandler": "chainlit.langchain.callbacks",
Expand All @@ -305,6 +315,7 @@ def sleep(duration: int):

__all__ = [
"user_session",
"CopilotFunction",
"Action",
"User",
"PersistedUser",
Expand Down
14 changes: 13 additions & 1 deletion backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from chainlit.logger import logger
from chainlit.version import __version__
from dataclasses_json import DataClassJsonMixin
from pydantic.dataclasses import dataclass
from pydantic.dataclasses import Field, dataclass
from starlette.datastructures import Headers

if TYPE_CHECKING:
Expand Down Expand Up @@ -40,6 +40,7 @@
# Whether to enable telemetry (default: true). No personal data is collected.
enable_telemetry = true
# List of environment variables to be provided by each user to use the app.
user_env = []
Expand All @@ -49,6 +50,9 @@
# Enable third parties caching (e.g LangChain cache)
cache = false
# Authorized origins
allow_origins = ["*"]
# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
# follow_symlink = false
Expand Down Expand Up @@ -97,7 +101,12 @@
# The CSS file can be served from the public directory or via an external link.
# custom_css = "/public/test.css"
# Specify a custom font url.
# custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"
# Override default MUI light theme. (Check theme.ts)
[UI.theme]
#font_family = "Inter, sans-serif"
[UI.theme.light]
#background = "#FAFAFA"
#paper = "#FFFFFF"
Expand Down Expand Up @@ -156,6 +165,7 @@ class Palette(DataClassJsonMixin):

@dataclass()
class Theme(DataClassJsonMixin):
font_family: Optional[str] = None
light: Optional[Palette] = None
dark: Optional[Palette] = None

Expand Down Expand Up @@ -188,6 +198,7 @@ class UISettings(DataClassJsonMixin):
theme: Optional[Theme] = None
# Optional custom CSS file that allows you to customize the UI
custom_css: Optional[str] = None
custom_font: Optional[str] = None


@dataclass()
Expand Down Expand Up @@ -217,6 +228,7 @@ class CodeSettings:

@dataclass()
class ProjectSettings(DataClassJsonMixin):
allow_origins: List[str] = Field(default_factory=lambda: ["*"])
enable_telemetry: bool = True
# List of environment variables to be provided by each user to use the app. If empty, no environment variables will be asked to the user.
user_env: Optional[List[str]] = None
Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def init_http_context(
id=str(uuid.uuid4()),
token=auth_token,
user=user,
client_type="app",
user_env=user_env,
)
context = ChainlitContext(session)
Expand Down
63 changes: 41 additions & 22 deletions backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Literal, Optional, Union, cast

from chainlit.data import get_data_layer
from chainlit.element import Element, File
Expand Down Expand Up @@ -37,8 +37,8 @@ async def emit(self, event: str, data: Any):
"""Stub method to get the 'emit' property from the session."""
pass

async def ask_user(self):
"""Stub method to get the 'ask_user' property from the session."""
async def emit_call(self):
"""Stub method to get the 'emit_call' property from the session."""
pass

async def resume_thread(self, thread_dict: ThreadDict):
Expand All @@ -57,12 +57,11 @@ async def delete_step(self, step_dict: StepDict):
"""Stub method to delete a message in the UI."""
pass

async def send_ask_timeout(self):
"""Stub method to send a prompt timeout message to the UI."""
def send_timeout(self, event: Literal["ask_timeout", "call_fn_timeout"]):
"""Stub method to send a timeout to the UI."""
pass

async def clear_ask(self):
"""Stub method to clear the prompt from the UI."""
def clear(self, event: Literal["clear_ask", "clear_call_fn"]):
pass

async def init_thread(self, interaction: str):
Expand All @@ -78,6 +77,12 @@ async def send_ask_user(
"""Stub method to send a prompt to the UI and wait for a response."""
pass

async def send_call_fn(
self, name: str, args: Dict[str, Any], timeout=300, raise_on_timeout=False
) -> Optional[Dict[str, Any]]:
"""Stub method to send a call function event to the copilot and wait for a response."""
pass

async def update_token_count(self, count: int):
"""Stub method to update the token count for the UI."""
pass
Expand Down Expand Up @@ -137,9 +142,9 @@ def emit(self):
return self._get_session_property("emit")

@property
def ask_user(self):
"""Get the 'ask_user' property from the session."""
return self._get_session_property("ask_user")
def emit_call(self):
"""Get the 'emit_call' property from the session."""
return self._get_session_property("emit_call")

def resume_thread(self, thread_dict: ThreadDict):
"""Send a thread to the UI to resume it"""
Expand All @@ -157,15 +162,11 @@ def delete_step(self, step_dict: StepDict):
"""Delete a message in the UI."""
return self.emit("delete_message", step_dict)

def send_ask_timeout(self):
"""Send a prompt timeout message to the UI."""

return self.emit("ask_timeout", {})
def send_timeout(self, event: Literal["ask_timeout", "call_fn_timeout"]):
return self.emit(event, {})

def clear_ask(self):
"""Clear the prompt from the UI."""

return self.emit("clear_ask", {})
def clear(self, event: Literal["clear_ask", "clear_call_fn"]):
return self.emit(event, {})

async def flush_thread_queues(self, interaction: str):
if data_layer := get_data_layer():
Expand Down Expand Up @@ -229,8 +230,8 @@ async def send_ask_user(

try:
# Send the prompt to the UI
user_res = await self.ask_user(
{"msg": step_dict, "spec": spec.to_dict()}, spec.timeout
user_res = await self.emit_call(
"ask", {"msg": step_dict, "spec": spec.to_dict()}, spec.timeout
) # type: Optional[Union["StepDict", "AskActionResponse", List["FileReference"]]]

# End the task temporarily so that the User can answer the prompt
Expand Down Expand Up @@ -279,16 +280,34 @@ async def send_ask_user(
self.session.has_first_interaction = True
await self.init_thread(interaction=interaction)

await self.clear_ask()
await self.clear("clear_ask")
return final_res
except TimeoutError as e:
await self.send_ask_timeout()
await self.send_timeout("ask_timeout")

if raise_on_timeout:
raise e
finally:
await self.task_start()

async def send_call_fn(
self, name: str, args: Dict[str, Any], timeout=300, raise_on_timeout=False
) -> Optional[Dict[str, Any]]:
"""Stub method to send a call function event to the copilot and wait for a response."""
try:
call_fn_res = await self.emit_call(
"call_fn", {"name": name, "args": args}, timeout
) # type: Dict

await self.clear("clear_call_fn")
return call_fn_res
except TimeoutError as e:
await self.send_timeout("call_fn_timeout")

if raise_on_timeout:
raise e
return None

def update_token_count(self, count: int):
"""Update the token count for the UI."""

Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class AskMessageBase(MessageBase):
async def remove(self):
removed = await super().remove()
if removed:
await context.emitter.clear_ask()
await context.emitter.clear("clear_ask")


class AskUserMessage(AskMessageBase):
Expand Down
41 changes: 33 additions & 8 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import json
import mimetypes
import re
import shutil
import urllib.parse
from typing import Any, Optional, Union
Expand Down Expand Up @@ -135,18 +136,19 @@ async def watch_files_for_changes():
os._exit(0)


def get_build_dir():
local_build_dir = os.path.join(PACKAGE_ROOT, "frontend", "dist")
packaged_build_dir = os.path.join(BACKEND_ROOT, "frontend", "dist")
def get_build_dir(local_target: str, packaged_target: str):
local_build_dir = os.path.join(PACKAGE_ROOT, local_target, "dist")
packaged_build_dir = os.path.join(BACKEND_ROOT, packaged_target, "dist")
if os.path.exists(local_build_dir):
return local_build_dir
elif os.path.exists(packaged_build_dir):
return packaged_build_dir
else:
raise FileNotFoundError("Built UI dir not found")
raise FileNotFoundError(f"{local_target} built UI dir not found")


build_dir = get_build_dir()
build_dir = get_build_dir("frontend", "frontend")
copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")


app = FastAPI(lifespan=lifespan)
Expand All @@ -161,19 +163,29 @@ def get_build_dir():
name="assets",
)

app.mount(
"/copilot",
StaticFiles(
packages=[("chainlit", copilot_build_dir)],
follow_symlink=config.project.follow_symlink,
),
name="copilot",
)


app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=config.project.allow_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


socket = SocketManager(
app,
cors_allowed_origins=[],
cors_allowed_origins=[]
if config.project.allow_origins[0] == "*"
else config.project.allow_origins,
async_mode="asgi",
)

Expand All @@ -183,6 +195,11 @@ def get_build_dir():
# -------------------------------------------------------------------------------


def replace_between_tags(text: str, start_tag: str, end_tag: str, replacement: str):
pattern = start_tag + ".*?" + end_tag
return re.sub(pattern, start_tag + replacement + end_tag, text, flags=re.DOTALL)


def get_html_template():
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
Expand All @@ -207,6 +224,10 @@ def get_html_template():
f"""<link rel="stylesheet" type="text/css" href="{config.ui.custom_css}">"""
)

font = None
if config.ui.custom_font:
font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""

index_html_file_path = os.path.join(build_dir, "index.html")

with open(index_html_file_path, "r", encoding="utf-8") as f:
Expand All @@ -216,6 +237,10 @@ def get_html_template():
content = content.replace(JS_PLACEHOLDER, js)
if css:
content = content.replace(CSS_PLACEHOLDER, css)
if font:
content = replace_between_tags(
content, "<!-- FONT START -->", "<!-- FONT END -->", font
)
return content


Expand Down
Loading

0 comments on commit 9c914c9

Please sign in to comment.