feat: allow optimizer model to be any litellm supported model (#307)
* feat: allow optimizer model to be any litellm supported model

* fix: stop operations
shreyashankar authored Feb 9, 2025
1 parent d7eef7d commit be7add5
Showing 14 changed files with 155 additions and 74 deletions.
9 changes: 8 additions & 1 deletion docetl/api.py
@@ -183,6 +183,7 @@ def optimize(
model: str = "gpt-4o",
resume: bool = False,
timeout: int = 60,
litellm_kwargs: Dict[str, Any] = {},
) -> "Pipeline":
"""
Optimize the pipeline using the Optimizer.
@@ -203,7 +204,13 @@ def optimize(
yaml_file_suffix=self.name,
max_threads=max_threads,
)
optimized_config, _ = runner.optimize(return_pipeline=False)
optimized_config, _ = runner.optimize(
model=model,
resume=resume,
timeout=timeout,
litellm_kwargs=litellm_kwargs,
return_pipeline=False,
)

updated_pipeline = Pipeline(
name=self.name,
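With the new keyword arguments, Pipeline.optimize can target any litellm-supported model. A minimal usage sketch (the pipeline construction is elided, and the model string and kwargs values are illustrative only):

    # Assumes `pipeline` is an already constructed docetl Pipeline.
    optimized = pipeline.optimize(
        model="claude-3-5-sonnet-20240620",   # any litellm-supported model id
        litellm_kwargs={"temperature": 0.0},  # extra kwargs passed through to the optimizer's LLM calls
        timeout=120,
    )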
8 changes: 7 additions & 1 deletion docetl/cli.py
@@ -29,6 +29,7 @@ def build(
):
"""
Build and optimize the configuration specified in the YAML file.
Any arguments passed here will override the values in the YAML file.
Args:
yaml_file (Path): Path to the YAML file containing the pipeline configuration.
@@ -47,7 +48,12 @@ def build(

runner = DSLRunner.from_yaml(str(yaml_file), max_threads=max_threads)
runner.optimize(
save=True, return_pipeline=False, model=model, resume=resume, timeout=timeout
save=True,
return_pipeline=False,
model=model or runner.config.get("optimizer_model", "gpt-4o"),
resume=resume,
timeout=timeout,
litellm_kwargs=runner.config.get("optimizer_litellm_kwargs", {}),
)


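The build command now falls back to two pipeline-level configuration keys when no overrides are passed on the command line. A sketch of how they might appear in the pipeline YAML (the model name and kwargs values are illustrative):

    optimizer_model: gemini/gemini-1.5-pro   # any litellm-supported model
    optimizer_litellm_kwargs:
      temperature: 0.0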
1 change: 1 addition & 0 deletions docetl/config_wrapper.py
@@ -109,6 +109,7 @@ def __init__(
)
bucket_factory = BucketCollection(**buckets)
self.rate_limiter = pyrate_limiter.Limiter(bucket_factory, max_delay=math.inf)
self.is_cancelled = False

self.api = APIWrapper(self)

3 changes: 3 additions & 0 deletions docetl/operations/map.py
@@ -2,6 +2,7 @@
The `MapOperation` and `ParallelMapOperation` classes are subclasses of `BaseOperation` that perform mapping operations on input data. They use LLM-based processing to transform input items into output items based on specified prompts and schemas, and can also perform key dropping operations.
"""

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -190,6 +191,8 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
return output, False

self.runner.rate_limiter.try_acquire("call", weight=1)
if self.runner.is_cancelled:
raise asyncio.CancelledError("Operation was cancelled")
llm_result = self.runner.api.call_llm(
self.config.get("model", self.default_model),
"map",
6 changes: 6 additions & 0 deletions docetl/operations/utils/api.py
@@ -1,4 +1,5 @@
import ast
import asyncio
import hashlib
import json
import re
@@ -71,6 +72,8 @@ def gen_embedding(self, model: str, input: List[str]) -> List[float]:

# FIXME: Should we use a different limit for embedding?
self.runner.rate_limiter.try_acquire("embedding_call", weight=1)
if self.runner.is_cancelled:
raise asyncio.CancelledError("Operation was cancelled")
result = embedding(model=model, input=input)
# Cache the result
c.set(key, result)
@@ -582,6 +585,9 @@ def _call_llm_with_cache(
messages = truncate_messages(messages, model)

self.runner.rate_limiter.try_acquire("llm_call", weight=1)
if self.runner.is_cancelled:
raise asyncio.CancelledError("Operation was cancelled")

if tools is not None:
try:
response = completion(
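Together with the map.py change above, the stop-operations fix uses one cooperative-cancellation pattern: a shared is_cancelled flag on the runner is checked right after each rate-limit acquisition, before any completion or embedding request is sent. A condensed sketch of that pattern (the helper name is hypothetical; the attributes and exception mirror the hunks above):

    import asyncio

    def guarded_llm_call(runner, *args, **kwargs):
        # Wait for a rate-limit slot first; a kill request may arrive in the meantime.
        runner.rate_limiter.try_acquire("llm_call", weight=1)
        # Bail out before spending tokens if the user asked to stop.
        if runner.is_cancelled:
            raise asyncio.CancelledError("Operation was cancelled")
        return runner.api.call_llm(*args, **kwargs)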
3 changes: 2 additions & 1 deletion docetl/optimizer.py
@@ -56,6 +56,7 @@ def __init__(
self,
runner: "DSLRunner",
model: str = "gpt-4o",
litellm_kwargs: Dict[str, Any] = {},
resume: bool = False,
timeout: int = 60,
):
@@ -98,7 +99,7 @@ def __init__(
self.status = runner.status

self.optimized_config = copy.deepcopy(self.config)
self.llm_client = LLMClient(model)
self.llm_client = LLMClient(runner, model, **litellm_kwargs)
self.timeout = timeout
self.resume = resume
self.captured_output = CapturedOutput()
6 changes: 5 additions & 1 deletion docetl/optimizers/utils.py
@@ -14,17 +14,20 @@ class LLMClient:
and keeps track of the total cost of API calls.
"""

def __init__(self, model: str = "gpt-4o"):
def __init__(self, runner, model: str = "gpt-4o", **litellm_kwargs):
"""
Initialize the LLMClient.
Args:
model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
**litellm_kwargs: Additional keyword arguments for the LLM model.
"""
if model == "gpt-4o":
model = "gpt-4o-2024-08-06"
self.model = model
self.litellm_kwargs = litellm_kwargs
self.total_cost = 0
self.runner = runner

def generate(
self,
@@ -59,6 +62,7 @@ def generate(
},
*messages,
],
**self.litellm_kwargs,
response_format={
"type": "json_schema",
"json_schema": {
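Because litellm_kwargs is spread directly into the underlying litellm call, the optimizer can be pointed at hosted or local models alike. A hypothetical instantiation (the model string and api_base value are examples, not values from this commit):

    client = LLMClient(
        runner,                              # stored as self.runner; its use is outside this hunk
        model="ollama/llama3",               # any litellm-supported model string
        api_base="http://localhost:11434",   # forwarded via **litellm_kwargs
    )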
13 changes: 13 additions & 0 deletions docetl/runner.py
@@ -622,6 +622,13 @@ def should_optimize(
self, step_name: str, op_name: str, **kwargs
) -> Tuple[str, float, List[Dict[str, Any]], List[Dict[str, Any]]]:
self.load()

# Augment the kwargs with the runner's config if not already provided
if "optimizer_litellm_kwargs" not in kwargs:
kwargs["litellm_kwargs"] = self.config.get("optimizer_litellm_kwargs", {})
if "optimizer_model" not in kwargs:
kwargs["model"] = self.config.get("optimizer_model", "gpt-4o")

builder = Optimizer(self, **kwargs)
self.optimizer = builder
result = builder.should_optimize(step_name, op_name)
@@ -639,6 +646,12 @@ def optimize(

self.load()

# Augment the kwargs with the runner's config if not already provided
if "optimizer_litellm_kwargs" not in kwargs:
kwargs["litellm_kwargs"] = self.config.get("optimizer_litellm_kwargs", {})
if "optimizer_model" not in kwargs:
kwargs["model"] = self.config.get("optimizer_model", "gpt-4o")

builder = Optimizer(
self,
**kwargs,
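With these guards, a runner is meant to pick up the optimizer settings from its pipeline config whenever they are not already provided. A sketch of that path (assuming a pipeline YAML that defines the optimizer_* keys shown earlier):

    runner = DSLRunner.from_yaml("pipeline.yaml", max_threads=4)
    # Neither model nor litellm_kwargs is passed here, so optimizer_model and
    # optimizer_litellm_kwargs are read from the loaded config (defaults:
    # "gpt-4o" and {}).
    optimized_config, _ = runner.optimize(return_pipeline=False)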
14 changes: 13 additions & 1 deletion server/app/routes/pipeline.py
@@ -57,7 +57,9 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_na
should_optimize, input_data, output_data, cost = await asyncio.to_thread(
runner.should_optimize,
step_name,
op_name
op_name,
model=runner.config.get("optimizer_model", "gpt-4o"),
litellm_kwargs=runner.config.get("optimizer_litellm_kwargs", {})
)

# Update task result
@@ -69,6 +71,7 @@ async def run_pipeline():
tasks[task_id].completed_at = datetime.now()

except asyncio.CancelledError:
runner.is_cancelled = True
tasks[task_id].status = TaskStatus.CANCELLED
tasks[task_id].completed_at = datetime.now()
raise
@@ -213,6 +216,9 @@ async def run_pipeline():

if user_message == "kill":
runner.console.log("Stopping process...")
runner.is_cancelled = True


await websocket.send_json({
"type": "error",
"message": "Process stopped by user request"
@@ -223,6 +229,12 @@
runner.console.post_input(user_message)
except asyncio.TimeoutError:
pass # No message received, continue with the loop
except asyncio.CancelledError:
await websocket.send_json({
"type": "error",
"message": "Process stopped by user request"
})
raise

await asyncio.sleep(0.5)

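On the client side, stopping a run is just a matter of sending the kill message over the open WebSocket; the handler above then sets runner.is_cancelled so the next guarded LLM call raises CancelledError. A rough client sketch (the endpoint path and the plain-text message framing are assumptions, not taken from this commit):

    import asyncio
    import websockets

    async def stop_pipeline() -> None:
        # URL and payload shape are guesses; adjust to the server's actual route.
        async with websockets.connect("ws://localhost:8000/ws/run_pipeline") as ws:
            await ws.send("kill")

    asyncio.run(stop_pipeline())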
4 changes: 3 additions & 1 deletion website/src/app/api/getPipelineConfig/route.ts
@@ -12,6 +12,7 @@ export async function POST(request: Request) {
sample_size,
namespace,
system_prompt,
optimizerModel,
} = await request.json();

if (!name) {
@@ -44,7 +45,8 @@
system_prompt,
[],
"",
false
false,
optimizerModel
);

return NextResponse.json({ pipelineConfig: yamlString });
18 changes: 13 additions & 5 deletions website/src/app/api/utils.ts
@@ -50,7 +50,8 @@ export function generatePipelineConfig(
} | null = null,
apiKeys: APIKey[] = [],
docetl_encryption_key: string = "",
enable_observability: boolean = true
enable_observability: boolean = true,
optimizerModel: string = "gpt-4o"
) {
const datasets = {
input: {
@@ -203,13 +204,18 @@ }
}

// Fetch all operations up until and including the operation_id
const operationsToRun = operations.slice(
0,
operations.findIndex((op: Operation) => op.id === operation_id) + 1
);
const operationsToRun = operations
.slice(
0,
operations.findIndex((op: Operation) => op.id === operation_id) + 1
)
.filter((op) =>
updatedOperations.some((updatedOp) => updatedOp.name === op.name)
);

// Fix type errors by asserting the pipeline config type
const pipelineConfig: any = {
optimizer_model: optimizerModel,
datasets,
default_model,
...(enable_observability && {
@@ -297,6 +303,8 @@

const yamlString = yaml.dump(pipelineConfig);

console.log(yamlString);

return {
yamlString,
inputPath,
4 changes: 3 additions & 1 deletion website/src/app/api/writePipelineConfig/route.ts
@@ -22,6 +22,7 @@ export async function POST(request: Request) {
system_prompt,
namespace,
apiKeys,
optimizerModel,
} = await request.json();

if (!name) {
@@ -55,7 +56,8 @@
system_prompt,
apiKeys,
docetl_encryption_key,
true
true,
optimizerModel
);

// Use the FastAPI endpoint to write the pipeline config
2 changes: 1 addition & 1 deletion website/src/components/OperationCard.tsx
@@ -906,6 +906,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
system_prompt: systemPrompt,
namespace: namespace,
apiKeys: apiKeys,
optimizerModel: optimizerModel,
}),
});

@@ -922,7 +923,6 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
sendMessage({
yaml_config: filePath,
optimize: true,
optimizer_model: optimizerModel,
});
} catch (error) {
console.error("Error optimizing operation:", error);
(The diff for 1 of the 14 changed files did not load and is not shown here.)
