diff --git a/docetl/api.py b/docetl/api.py
index 1a5c2028..c3d35acf 100644
--- a/docetl/api.py
+++ b/docetl/api.py
@@ -183,6 +183,7 @@ def optimize(
         model: str = "gpt-4o",
         resume: bool = False,
         timeout: int = 60,
+        litellm_kwargs: Dict[str, Any] = {},
     ) -> "Pipeline":
         """
         Optimize the pipeline using the Optimizer.
@@ -203,7 +204,13 @@ def optimize(
             yaml_file_suffix=self.name,
             max_threads=max_threads,
         )
-        optimized_config, _ = runner.optimize(return_pipeline=False)
+        optimized_config, _ = runner.optimize(
+            model=model,
+            resume=resume,
+            timeout=timeout,
+            litellm_kwargs=litellm_kwargs,
+            return_pipeline=False,
+        )

         updated_pipeline = Pipeline(
             name=self.name,
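As an illustration of how the new parameter is meant to be used, the `litellm_kwargs` passed to `Pipeline.optimize` are forwarded to every LLM call the optimizer makes. A minimal sketch (the Azure endpoint values are placeholders, and the function assumes an already-constructed docetl `Pipeline`):

    from docetl.api import Pipeline

    def optimize_against_azure(pipeline: Pipeline) -> Pipeline:
        """Run the optimizer against an Azure deployment; the kwargs are ordinary litellm completion kwargs."""
        return pipeline.optimize(
            model="azure/gpt-4o",
            litellm_kwargs={
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder endpoint
                "api_version": "2024-08-01-preview",  # placeholder API version
            },
        )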
diff --git a/docetl/cli.py b/docetl/cli.py
index 8619eb98..bda275c4 100644
--- a/docetl/cli.py
+++ b/docetl/cli.py
@@ -29,6 +29,7 @@ def build(
 ):
     """
     Build and optimize the configuration specified in the YAML file.
+    Any arguments passed here will override the values in the YAML file.

     Args:
         yaml_file (Path): Path to the YAML file containing the pipeline configuration.
@@ -47,7 +48,12 @@ def build(
     runner = DSLRunner.from_yaml(str(yaml_file), max_threads=max_threads)

     runner.optimize(
-        save=True, return_pipeline=False, model=model, resume=resume, timeout=timeout
+        save=True,
+        return_pipeline=False,
+        model=model or runner.config.get("optimizer_model", "gpt-4o"),
+        resume=resume,
+        timeout=timeout,
+        litellm_kwargs=runner.config.get("optimizer_litellm_kwargs", {}),
     )
diff --git a/docetl/config_wrapper.py b/docetl/config_wrapper.py
index f60bcf1b..7bb406f4 100644
--- a/docetl/config_wrapper.py
+++ b/docetl/config_wrapper.py
@@ -109,6 +109,7 @@ def __init__(
         )
         bucket_factory = BucketCollection(**buckets)
         self.rate_limiter = pyrate_limiter.Limiter(bucket_factory, max_delay=math.inf)
+        self.is_cancelled = False

         self.api = APIWrapper(self)
diff --git a/docetl/operations/map.py b/docetl/operations/map.py
index f3d6231f..648139bb 100644
--- a/docetl/operations/map.py
+++ b/docetl/operations/map.py
@@ -2,6 +2,7 @@
 The `MapOperation` and `ParallelMapOperation` classes are subclasses of `BaseOperation` that perform mapping operations on input data. They use LLM-based processing to transform input items into output items based on specified prompts and schemas, and can also perform key dropping operations.
 """
+import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -190,6 +191,8 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
                 return output, False

             self.runner.rate_limiter.try_acquire("call", weight=1)
+            if self.runner.is_cancelled:
+                raise asyncio.CancelledError("Operation was cancelled")
             llm_result = self.runner.api.call_llm(
                 self.config.get("model", self.default_model),
                 "map",
diff --git a/docetl/operations/utils/api.py b/docetl/operations/utils/api.py
index ccef74e5..edf7f532 100644
--- a/docetl/operations/utils/api.py
+++ b/docetl/operations/utils/api.py
@@ -1,4 +1,5 @@
 import ast
+import asyncio
 import hashlib
 import json
 import re
@@ -71,6 +72,8 @@ def gen_embedding(self, model: str, input: List[str]) -> List[float]:
             # FIXME: Should we use a different limit for embedding?
             self.runner.rate_limiter.try_acquire("embedding_call", weight=1)
+            if self.runner.is_cancelled:
+                raise asyncio.CancelledError("Operation was cancelled")
             result = embedding(model=model, input=input)
             # Cache the result
             c.set(key, result)
@@ -582,6 +585,9 @@ def _call_llm_with_cache(
             messages = truncate_messages(messages, model)

         self.runner.rate_limiter.try_acquire("llm_call", weight=1)
+        if self.runner.is_cancelled:
+            raise asyncio.CancelledError("Operation was cancelled")
+
         if tools is not None:
             try:
                 response = completion(
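The cancellation changes above all follow one pattern: the `is_cancelled` flag initialized on the runner in config_wrapper.py is checked right after a rate-limit slot is acquired and before the LLM or embedding call is made, so a long-running pipeline stops at the next call boundary rather than mid-request. A minimal, self-contained sketch of that pattern (the `Runner` class and `guarded_llm_call` function here are illustrative stand-ins, not docetl APIs):

    import asyncio

    class Runner:
        """Stand-in for the shared runner object that carries the cancellation flag."""

        def __init__(self) -> None:
            self.is_cancelled = False

    def guarded_llm_call(runner: Runner, prompt: str) -> str:
        # Same check the patch adds: bail out before spending another LLM call.
        if runner.is_cancelled:
            raise asyncio.CancelledError("Operation was cancelled")
        return f"response to: {prompt}"  # placeholder for the real completion call

    runner = Runner()
    runner.is_cancelled = True  # e.g. set by the server when a user sends "kill"
    try:
        guarded_llm_call(runner, "summarize this document")
    except asyncio.CancelledError as exc:
        print(f"cancelled: {exc}")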
diff --git a/docetl/optimizer.py b/docetl/optimizer.py
index 20f031de..04a747e0 100644
--- a/docetl/optimizer.py
+++ b/docetl/optimizer.py
@@ -56,6 +56,7 @@ def __init__(
         self,
         runner: "DSLRunner",
         model: str = "gpt-4o",
+        litellm_kwargs: Dict[str, Any] = {},
         resume: bool = False,
         timeout: int = 60,
     ):
@@ -98,7 +99,7 @@ def __init__(
         self.status = runner.status

         self.optimized_config = copy.deepcopy(self.config)
-        self.llm_client = LLMClient(model)
+        self.llm_client = LLMClient(runner, model, **litellm_kwargs)
         self.timeout = timeout
         self.resume = resume
         self.captured_output = CapturedOutput()
diff --git a/docetl/optimizers/utils.py b/docetl/optimizers/utils.py
index d0305afd..b7a40c46 100644
--- a/docetl/optimizers/utils.py
+++ b/docetl/optimizers/utils.py
@@ -14,17 +14,20 @@ class LLMClient:
     and keeps track of the total cost of API calls.
     """

-    def __init__(self, model: str = "gpt-4o"):
+    def __init__(self, runner, model: str = "gpt-4o", **litellm_kwargs):
         """
         Initialize the LLMClient.

         Args:
             model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
+            **litellm_kwargs: Additional keyword arguments for the LLM model.
         """
         if model == "gpt-4o":
             model = "gpt-4o-2024-08-06"
         self.model = model
+        self.litellm_kwargs = litellm_kwargs
         self.total_cost = 0
+        self.runner = runner

     def generate(
         self,
@@ -59,6 +62,7 @@ def generate(
                 },
                 *messages,
             ],
+            **self.litellm_kwargs,
             response_format={
                 "type": "json_schema",
                 "json_schema": {
diff --git a/docetl/runner.py b/docetl/runner.py
index e1bcc83c..3dda2c12 100644
--- a/docetl/runner.py
+++ b/docetl/runner.py
@@ -622,6 +622,13 @@ def should_optimize(
         self, step_name: str, op_name: str, **kwargs
     ) -> Tuple[str, float, List[Dict[str, Any]], List[Dict[str, Any]]]:
         self.load()
+
+        # Augment the kwargs with the runner's config if not already provided
+        if "litellm_kwargs" not in kwargs:
+            kwargs["litellm_kwargs"] = self.config.get("optimizer_litellm_kwargs", {})
+        if "model" not in kwargs:
+            kwargs["model"] = self.config.get("optimizer_model", "gpt-4o")
+
         builder = Optimizer(self, **kwargs)
         self.optimizer = builder
         result = builder.should_optimize(step_name, op_name)
@@ -639,6 +646,12 @@ def optimize(

         self.load()

+        # Augment the kwargs with the runner's config if not already provided
+        if "litellm_kwargs" not in kwargs:
+            kwargs["litellm_kwargs"] = self.config.get("optimizer_litellm_kwargs", {})
+        if "model" not in kwargs:
+            kwargs["model"] = self.config.get("optimizer_model", "gpt-4o")
+
         builder = Optimizer(
             self,
             **kwargs,
diff --git a/server/app/routes/pipeline.py b/server/app/routes/pipeline.py
index e09b6ef1..9773a81e 100644
--- a/server/app/routes/pipeline.py
+++ b/server/app/routes/pipeline.py
@@ -57,7 +57,9 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
         should_optimize, input_data, output_data, cost = await asyncio.to_thread(
             runner.should_optimize,
             step_name,
-            op_name
+            op_name,
+            model=runner.config.get("optimizer_model", "gpt-4o"),
+            litellm_kwargs=runner.config.get("optimizer_litellm_kwargs", {})
         )

         # Update task result
@@ -69,6 +71,7 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
             tasks[task_id].completed_at = datetime.now()

     except asyncio.CancelledError:
+        runner.is_cancelled = True
         tasks[task_id].status = TaskStatus.CANCELLED
         tasks[task_id].completed_at = datetime.now()
         raise
@@ -213,6 +216,9 @@ async def run_pipeline():
                     if user_message == "kill":
                         runner.console.log("Stopping process...")
+                        runner.is_cancelled = True
+
+
                         await websocket.send_json({
                             "type": "error",
                             "message": "Process stopped by user request"
                         })
@@ -223,6 +229,12 @@ async def run_pipeline():
                         runner.console.post_input(user_message)
                 except asyncio.TimeoutError:
                     pass  # No message received, continue with the loop
+                except asyncio.CancelledError:
+                    await websocket.send_json({
+                        "type": "error",
+                        "message": "Process stopped by user request"
+                    })
+                    raise

                 await asyncio.sleep(0.5)
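Together with the CLI and runner changes, the optimizer can now be configured entirely from the pipeline YAML via two top-level keys, `optimizer_model` and `optimizer_litellm_kwargs`, which are used whenever the caller does not pass explicit values. A hedged sketch of driving this from Python (the YAML file name and the Azure-style kwargs are placeholders; `DSLRunner.from_yaml` and `optimize(save=..., return_pipeline=...)` are the same entry points the CLI uses above):

    # pipeline.yaml is assumed to contain, at the top level, something like:
    #
    #   optimizer_model: azure/gpt-4o
    #   optimizer_litellm_kwargs:
    #     api_base: https://my-endpoint.openai.azure.com
    #     api_version: 2024-08-01-preview
    #
    from docetl.runner import DSLRunner

    runner = DSLRunner.from_yaml("pipeline.yaml", max_threads=4)

    # With no explicit model/litellm_kwargs, optimize() falls back to the
    # optimizer_* values in the loaded config, per the runner changes above.
    runner.optimize(save=True, return_pipeline=False)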
diff --git a/website/src/app/api/getPipelineConfig/route.ts b/website/src/app/api/getPipelineConfig/route.ts
index 77e3d16c..c77e8f52 100644
--- a/website/src/app/api/getPipelineConfig/route.ts
+++ b/website/src/app/api/getPipelineConfig/route.ts
@@ -12,6 +12,7 @@ export async function POST(request: Request) {
       sample_size,
       namespace,
       system_prompt,
+      optimizerModel,
     } = await request.json();

     if (!name) {
@@ -44,7 +45,8 @@ export async function POST(request: Request) {
       system_prompt,
       [],
       "",
-      false
+      false,
+      optimizerModel
     );

     return NextResponse.json({ pipelineConfig: yamlString });
diff --git a/website/src/app/api/utils.ts b/website/src/app/api/utils.ts
index c90f9296..311e41cd 100644
--- a/website/src/app/api/utils.ts
+++ b/website/src/app/api/utils.ts
@@ -50,7 +50,8 @@ export function generatePipelineConfig(
   } | null = null,
   apiKeys: APIKey[] = [],
   docetl_encryption_key: string = "",
-  enable_observability: boolean = true
+  enable_observability: boolean = true,
+  optimizerModel: string = "gpt-4o"
 ) {
   const datasets = {
     input: {
@@ -203,13 +204,18 @@ export function generatePipelineConfig(
   }

   // Fetch all operations up until and including the operation_id
-  const operationsToRun = operations.slice(
-    0,
-    operations.findIndex((op: Operation) => op.id === operation_id) + 1
-  );
+  const operationsToRun = operations
+    .slice(
+      0,
+      operations.findIndex((op: Operation) => op.id === operation_id) + 1
+    )
+    .filter((op) =>
+      updatedOperations.some((updatedOp) => updatedOp.name === op.name)
+    );

   // Fix type errors by asserting the pipeline config type
   const pipelineConfig: any = {
+    optimizer_model: optimizerModel,
     datasets,
     default_model,
     ...(enable_observability && {
@@ -297,6 +303,8 @@ export function generatePipelineConfig(

   const yamlString = yaml.dump(pipelineConfig);

+  console.log(yamlString);
+
   return {
     yamlString,
     inputPath,
diff --git a/website/src/app/api/writePipelineConfig/route.ts b/website/src/app/api/writePipelineConfig/route.ts
index f9905971..1d6baa8a 100644
--- a/website/src/app/api/writePipelineConfig/route.ts
+++ b/website/src/app/api/writePipelineConfig/route.ts
@@ -22,6 +22,7 @@ export async function POST(request: Request) {
     system_prompt,
     namespace,
     apiKeys,
+    optimizerModel,
   } = await request.json();

   if (!name) {
@@ -55,7 +56,8 @@ export async function POST(request: Request) {
     system_prompt,
     apiKeys,
     docetl_encryption_key,
-    true
+    true,
+    optimizerModel
   );

   // Use the FastAPI endpoint to write the pipeline config
diff --git a/website/src/components/OperationCard.tsx b/website/src/components/OperationCard.tsx
index 84bf50e7..65c5173e 100644
--- a/website/src/components/OperationCard.tsx
+++ b/website/src/components/OperationCard.tsx
@@ -906,6 +906,7 @@ export const OperationCard: React.FC = ({ index, id }) => {
           system_prompt: systemPrompt,
           namespace: namespace,
           apiKeys: apiKeys,
+          optimizerModel: optimizerModel,
         }),
       });

@@ -922,7 +923,6 @@ export const OperationCard: React.FC = ({ index, id }) => {
       sendMessage({
         yaml_config: filePath,
         optimize: true,
-        optimizer_model: optimizerModel,
       });
     } catch (error) {
       console.error("Error optimizing operation:", error);
diff --git a/website/src/components/PipelineGui.tsx b/website/src/components/PipelineGui.tsx
index c8d61b51..74331623 100644
--- a/website/src/components/PipelineGui.tsx
+++ b/website/src/components/PipelineGui.tsx
@@ -248,6 +248,56 @@ const AddOperationDropdown: React.FC = ({
   );
 };

+const ModelInput: React.FC<{
+  value: string;
+  onChange: (value: string) => void;
+  placeholder: string;
+  suggestions?: readonly string[];
+}> = ({ value, onChange, placeholder, suggestions = PREDEFINED_MODELS }) => {
+  const [isFocused, setIsFocused] = useState(false);
+
+  return (
+    <div>
+      <Input
+        value={value}
+        onChange={(e) => onChange(e.target.value)}
+        className="w-full"
+        placeholder={placeholder}
+        onFocus={() => setIsFocused(true)}
+        onBlur={() => {
+          setTimeout(() => setIsFocused(false), 200);
+        }}
+      />
+      {isFocused &&
+        (value === "" ||
+          suggestions.some((model) =>
+            model.toLowerCase().includes(value?.toLowerCase() || "")
+          )) && (
+          <div>
+            {suggestions
+              .filter(
+                (model) =>
+                  value === "" ||
+                  model.toLowerCase().includes(value.toLowerCase())
+              )
+              .map((model) => (
+                <div
+                  onClick={() => {
+                    onChange(model);
+                    setIsFocused(false);
+                  }}
+                >
+                  {model}
+                </div>
+              ))}
+          </div>
+        )}
+    </div>
+  );
+};
+
 const PipelineGUI: React.FC = () => {
   const fileInputRef = useRef(null);
   const headerRef = useRef(null);
@@ -572,8 +622,9 @@ const PipelineGUI: React.FC = () => {
           operation_id: operations[operations.length - 1].id,
           name: pipelineName,
           sample_size: sampleSize,
-          namespace,
+          namespace: namespace,
           system_prompt: systemPrompt,
+          optimizerModel: optimizerModel,
         }),
       });

@@ -651,8 +702,9 @@ const PipelineGUI: React.FC = () => {
           sample_size: sampleSize,
           clear_intermediate: clear_intermediate,
           system_prompt: systemPrompt,
-          namespace,
+          namespace: namespace,
           apiKeys: currentApiKeys, // Use the latest API keys
+          optimizerModel: optimizerModel,
         }),
       });

@@ -755,8 +807,9 @@ const PipelineGUI: React.FC = () => {
           name: pipelineName,
           sample_size: sampleSize,
           optimize: true,
-          namespace,
-          apiKeys,
+          namespace: namespace,
+          apiKeys: apiKeys,
+          optimizerModel: optimizerModel,
         }),
       });

@@ -771,7 +824,6 @@ const PipelineGUI: React.FC = () => {
       sendMessage({
         yaml_config: filePath,
         optimize: true,
-        optimizer_model: optimizerModel,
       });
     } catch (error) {
       console.error("Error optimizing operation:", error);
@@ -1188,52 +1240,15 @@ const PipelineGUI: React.FC = () => {
-                  <div>
-                    <Input
-                      value={tempDefaultModel}
-                      onChange={(e) => setTempDefaultModel(e.target.value)}
-                      className="w-full"
-                      placeholder="Enter or select a model..."
-                      onFocus={() => setIsModelInputFocused(true)}
-                      onBlur={() => {
-                        setTimeout(() => setIsModelInputFocused(false), 200);
-                      }}
-                    />
-                    {isModelInputFocused &&
-                      (tempDefaultModel === "" ||
-                        PREDEFINED_MODELS.some((model) =>
-                          model
-                            .toLowerCase()
-                            .includes(tempDefaultModel?.toLowerCase() || "")
-                        )) && (
-                        <div>
-                          {PREDEFINED_MODELS.filter(
-                            (model) =>
-                              tempDefaultModel === "" ||
-                              model
-                                .toLowerCase()
-                                .includes(tempDefaultModel.toLowerCase())
-                          ).map((model) => (
-                            <div
-                              onClick={() => {
-                                setTempDefaultModel(model);
-                                setIsModelInputFocused(false);
-                              }}
-                            >
-                              {model}
-                            </div>
-                          ))}
-                        </div>
-                      )}
-                  </div>
+                  <ModelInput
+                    value={tempDefaultModel}
+                    onChange={setTempDefaultModel}
+                    placeholder="Enter or select a model..."
+                  />

                     Enter any LiteLLM model name or select from suggestions. Make
-                    sure you've set your API keys in Edit{" "}
-                    {String.fromCharCode(8594)} Edit API Keys when using our hosted
-                    app.{" "}
+                    sure you've set your API keys in Edit {">"} Edit API Keys
+                    when using our hosted app.{" "}
                     {

                 ) : (
-
+
+
+                    Enter any LiteLLM model name (e.g.,
+                    "azure/gpt-4o") or select from suggestions above.
+                    Make sure the model supports JSON mode.
+
+
                 )}