feat: allow optimizer model to be any litellm supported model #307

Merged · 2 commits · Feb 9, 2025
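This PR threads an `optimizer_model` setting and pass-through `optimizer_litellm_kwargs` from the web UI and pipeline YAML down through the CLI, runner, and `Optimizer`, so the optimizer can use any litellm-supported model instead of a hardcoded gpt-4o. It also adds a cooperative `runner.is_cancelled` flag, checked before each LLM and embedding call, so long-running optimizations can be stopped from the server.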
9 changes: 8 additions & 1 deletion docetl/api.py
@@ -183,6 +183,7 @@ def optimize(
         model: str = "gpt-4o",
         resume: bool = False,
         timeout: int = 60,
+        litellm_kwargs: Dict[str, Any] = {},
     ) -> "Pipeline":
         """
         Optimize the pipeline using the Optimizer.
@@ -203,7 +204,13 @@ def optimize(
             yaml_file_suffix=self.name,
             max_threads=max_threads,
         )
-        optimized_config, _ = runner.optimize(return_pipeline=False)
+        optimized_config, _ = runner.optimize(
+            model=model,
+            resume=resume,
+            timeout=timeout,
+            litellm_kwargs=litellm_kwargs,
+            return_pipeline=False,
+        )

         updated_pipeline = Pipeline(
             name=self.name,
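With this change, `Pipeline.optimize` threads the optimizer model and any extra litellm completion kwargs down to the runner. A minimal sketch of the new call surface, assuming an already-constructed `docetl.api.Pipeline` instance (the model id and kwargs are illustrative, not from this PR):

```python
# Sketch: `pipeline` is an existing docetl.api.Pipeline.
# Any litellm-supported model id should work; the kwargs are passed
# through to the litellm completion calls made by the optimizer.
optimized = pipeline.optimize(
    model="claude-3-5-sonnet-20240620",   # illustrative model id
    litellm_kwargs={"temperature": 0.2},  # illustrative kwargs
)
```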
8 changes: 7 additions & 1 deletion docetl/cli.py
@@ -29,6 +29,7 @@ def build(
 ):
     """
     Build and optimize the configuration specified in the YAML file.
+    Any arguments passed here will override the values in the YAML file.

     Args:
         yaml_file (Path): Path to the YAML file containing the pipeline configuration.
@@ -47,7 +48,12 @@

     runner = DSLRunner.from_yaml(str(yaml_file), max_threads=max_threads)
     runner.optimize(
-        save=True, return_pipeline=False, model=model, resume=resume, timeout=timeout
+        save=True,
+        return_pipeline=False,
+        model=model or runner.config.get("optimizer_model", "gpt-4o"),
+        resume=resume,
+        timeout=timeout,
+        litellm_kwargs=runner.config.get("optimizer_litellm_kwargs", {}),
     )

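The CLI now falls back to two optional keys in the pipeline YAML when the corresponding flags are absent. A sketch of what they might look like; only the key names (`optimizer_model`, `optimizer_litellm_kwargs`) are confirmed by this diff, the values are illustrative:

```yaml
optimizer_model: gemini/gemini-1.5-pro  # any litellm-supported model id
optimizer_litellm_kwargs:               # forwarded to litellm completion calls
  temperature: 0.0
  max_tokens: 4096
```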
1 change: 1 addition & 0 deletions docetl/config_wrapper.py
@@ -109,6 +109,7 @@ def __init__(
         )
         bucket_factory = BucketCollection(**buckets)
         self.rate_limiter = pyrate_limiter.Limiter(bucket_factory, max_delay=math.inf)
+        self.is_cancelled = False

         self.api = APIWrapper(self)

3 changes: 3 additions & 0 deletions docetl/operations/map.py
@@ -2,6 +2,7 @@
 The `MapOperation` and `ParallelMapOperation` classes are subclasses of `BaseOperation` that perform mapping operations on input data. They use LLM-based processing to transform input items into output items based on specified prompts and schemas, and can also perform key dropping operations.
 """

+import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -190,6 +191,8 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
                 return output, False

             self.runner.rate_limiter.try_acquire("call", weight=1)
+            if self.runner.is_cancelled:
+                raise asyncio.CancelledError("Operation was cancelled")
             llm_result = self.runner.api.call_llm(
                 self.config.get("model", self.default_model),
                 "map",
6 changes: 6 additions & 0 deletions docetl/operations/utils/api.py
@@ -1,4 +1,5 @@
 import ast
+import asyncio
 import hashlib
 import json
 import re
@@ -71,6 +72,8 @@ def gen_embedding(self, model: str, input: List[str]) -> List[float]:

         # FIXME: Should we use a different limit for embedding?
         self.runner.rate_limiter.try_acquire("embedding_call", weight=1)
+        if self.runner.is_cancelled:
+            raise asyncio.CancelledError("Operation was cancelled")
         result = embedding(model=model, input=input)
         # Cache the result
         c.set(key, result)
@@ -582,6 +585,9 @@ def _call_llm_with_cache(
         messages = truncate_messages(messages, model)

         self.runner.rate_limiter.try_acquire("llm_call", weight=1)
+        if self.runner.is_cancelled:
+            raise asyncio.CancelledError("Operation was cancelled")
+
         if tools is not None:
             try:
                 response = completion(
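Together with the `is_cancelled` flag added in `config_wrapper.py`, these guards implement cooperative cancellation: worker code polls a shared flag right before each expensive call and bails out by raising `asyncio.CancelledError`. A self-contained sketch of the pattern (the `Runner` class here is a stand-in, not docetl's):

```python
import asyncio

class Runner:
    """Stand-in for docetl's runner: a shared flag that workers poll."""
    def __init__(self) -> None:
        self.is_cancelled = False

def expensive_step(runner: Runner) -> str:
    # Check the flag at each safe point, just before expensive work.
    if runner.is_cancelled:
        raise asyncio.CancelledError("Operation was cancelled")
    return "llm result"

runner = Runner()
runner.is_cancelled = True  # e.g., flipped by a "kill" message handler
try:
    expensive_step(runner)
except asyncio.CancelledError:
    print("aborted before making the call")
```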
3 changes: 2 additions & 1 deletion docetl/optimizer.py
@@ -56,6 +56,7 @@ def __init__(
         self,
         runner: "DSLRunner",
         model: str = "gpt-4o",
+        litellm_kwargs: Dict[str, Any] = {},
         resume: bool = False,
         timeout: int = 60,
     ):
@@ -98,7 +99,7 @@ def __init__(
         self.status = runner.status

         self.optimized_config = copy.deepcopy(self.config)
-        self.llm_client = LLMClient(model)
+        self.llm_client = LLMClient(runner, model, **litellm_kwargs)
         self.timeout = timeout
         self.resume = resume
         self.captured_output = CapturedOutput()
6 changes: 5 additions & 1 deletion docetl/optimizers/utils.py
@@ -14,17 +14,20 @@ class LLMClient:
     and keeps track of the total cost of API calls.
     """

-    def __init__(self, model: str = "gpt-4o"):
+    def __init__(self, runner, model: str = "gpt-4o", **litellm_kwargs):
         """
         Initialize the LLMClient.

         Args:
             model (str, optional): The name of the LLM model to use. Defaults to "gpt-4o".
+            **litellm_kwargs: Additional keyword arguments for the LLM model.
         """
         if model == "gpt-4o":
             model = "gpt-4o-2024-08-06"
         self.model = model
+        self.litellm_kwargs = litellm_kwargs
         self.total_cost = 0
+        self.runner = runner

     def generate(
         self,
@@ -59,6 +62,7 @@ def generate(
                 },
                 *messages,
             ],
+            **self.litellm_kwargs,
             response_format={
                 "type": "json_schema",
                 "json_schema": {
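Every `generate` call now splats the stored kwargs into litellm's `completion`, so anything litellm accepts for the chosen provider can be set once at construction time. A sketch of direct construction, assuming an existing `DSLRunner` instance (the extra kwargs are examples, not requirements):

```python
from docetl.optimizers.utils import LLMClient

# `runner` is assumed to be an existing DSLRunner instance.
client = LLMClient(
    runner,
    model="azure/gpt-4o",  # illustrative litellm model id
    api_base="https://example-endpoint.openai.azure.com",  # illustrative
    temperature=0.0,
)
# Each client.generate(...) call now forwards api_base and temperature
# to litellm.completion alongside the messages and response_format.
```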
13 changes: 13 additions & 0 deletions docetl/runner.py
@@ -622,6 +622,13 @@ def should_optimize(
         self, step_name: str, op_name: str, **kwargs
     ) -> Tuple[str, float, List[Dict[str, Any]], List[Dict[str, Any]]]:
         self.load()
+
+        # Augment the kwargs with the runner's config if not already provided
+        if "optimizer_litellm_kwargs" not in kwargs:
+            kwargs["litellm_kwargs"] = self.config.get("optimizer_litellm_kwargs", {})
+        if "optimizer_model" not in kwargs:
+            kwargs["model"] = self.config.get("optimizer_model", "gpt-4o")
+
         builder = Optimizer(self, **kwargs)
         self.optimizer = builder
         result = builder.should_optimize(step_name, op_name)
@@ -639,6 +646,12 @@ def optimize(

         self.load()

+        # Augment the kwargs with the runner's config if not already provided
+        if "optimizer_litellm_kwargs" not in kwargs:
+            kwargs["litellm_kwargs"] = self.config.get("optimizer_litellm_kwargs", {})
+        if "optimizer_model" not in kwargs:
+            kwargs["model"] = self.config.get("optimizer_model", "gpt-4o")
+
         builder = Optimizer(
             self,
             **kwargs,
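Both entry points fill in `model` and `litellm_kwargs` from the YAML config whenever the caller has not supplied the `optimizer_*` keys. A reduced, standalone sketch of that fallback, mirroring the diff rather than importing docetl:

```python
def resolve_optimizer_args(config: dict, **kwargs) -> dict:
    # Mirrors the augmentation in DSLRunner.optimize / should_optimize:
    # YAML config values fill in when the optimizer_* keys are absent.
    if "optimizer_litellm_kwargs" not in kwargs:
        kwargs["litellm_kwargs"] = config.get("optimizer_litellm_kwargs", {})
    if "optimizer_model" not in kwargs:
        kwargs["model"] = config.get("optimizer_model", "gpt-4o")
    return kwargs

print(resolve_optimizer_args({"optimizer_model": "o1-mini"}))
# -> {'litellm_kwargs': {}, 'model': 'o1-mini'}
```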
14 changes: 13 additions & 1 deletion server/app/routes/pipeline.py
@@ -57,7 +57,9 @@ async def run_optimization(task_id: str, yaml_config: str, step_name: str, op_name: str):
         should_optimize, input_data, output_data, cost = await asyncio.to_thread(
             runner.should_optimize,
             step_name,
-            op_name
+            op_name,
+            model=runner.config.get("optimizer_model", "gpt-4o"),
+            litellm_kwargs=runner.config.get("optimizer_litellm_kwargs", {})
         )

         # Update task result
@@ -69,6 +71,7 @@
         tasks[task_id].completed_at = datetime.now()

     except asyncio.CancelledError:
+        runner.is_cancelled = True
         tasks[task_id].status = TaskStatus.CANCELLED
         tasks[task_id].completed_at = datetime.now()
         raise
@@ -213,6 +216,9 @@ async def run_pipeline():

                     if user_message == "kill":
                         runner.console.log("Stopping process...")
+                        runner.is_cancelled = True
+
+
                         await websocket.send_json({
                             "type": "error",
                             "message": "Process stopped by user request"
@@ -223,6 +229,12 @@
                             runner.console.post_input(user_message)
                     except asyncio.TimeoutError:
                         pass  # No message received, continue with the loop
+                    except asyncio.CancelledError:
+                        await websocket.send_json({
+                            "type": "error",
+                            "message": "Process stopped by user request"
+                        })
+                        raise

                     await asyncio.sleep(0.5)

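On the server, cancellation flows in both directions: the websocket loop sets `runner.is_cancelled` when the user sends "kill" (the worker then aborts at its next flag check), and a propagating `asyncio.CancelledError` marks the task as cancelled before being re-raised. A compact, runnable sketch of that control flow using toy types (not the FastAPI handler):

```python
import asyncio

class Runner:
    def __init__(self) -> None:
        self.is_cancelled = False

async def worker(runner: Runner) -> None:
    # Stands in for the pipeline task: polls the flag between steps.
    while not runner.is_cancelled:
        await asyncio.sleep(0.1)
    raise asyncio.CancelledError("Operation was cancelled")

async def main() -> None:
    runner = Runner()
    task = asyncio.create_task(worker(runner))
    await asyncio.sleep(0.3)
    runner.is_cancelled = True   # what the "kill" branch does
    try:
        await task
    except asyncio.CancelledError:
        print("task cancelled")  # where the route marks TaskStatus.CANCELLED

asyncio.run(main())
```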
4 changes: 3 additions & 1 deletion website/src/app/api/getPipelineConfig/route.ts
@@ -12,6 +12,7 @@ export async function POST(request: Request) {
     sample_size,
     namespace,
     system_prompt,
+    optimizerModel,
   } = await request.json();

   if (!name) {
@@ -44,7 +45,8 @@
     system_prompt,
     [],
     "",
-    false
+    false,
+    optimizerModel
   );

   return NextResponse.json({ pipelineConfig: yamlString });
18 changes: 13 additions & 5 deletions website/src/app/api/utils.ts
@@ -50,7 +50,8 @@ export function generatePipelineConfig(
   } | null = null,
   apiKeys: APIKey[] = [],
   docetl_encryption_key: string = "",
-  enable_observability: boolean = true
+  enable_observability: boolean = true,
+  optimizerModel: string = "gpt-4o"
 ) {
   const datasets = {
     input: {
@@ -203,13 +204,18 @@
   }

   // Fetch all operations up until and including the operation_id
-  const operationsToRun = operations.slice(
-    0,
-    operations.findIndex((op: Operation) => op.id === operation_id) + 1
-  );
+  const operationsToRun = operations
+    .slice(
+      0,
+      operations.findIndex((op: Operation) => op.id === operation_id) + 1
+    )
+    .filter((op) =>
+      updatedOperations.some((updatedOp) => updatedOp.name === op.name)
+    );

   // Fix type errors by asserting the pipeline config type
   const pipelineConfig: any = {
+    optimizer_model: optimizerModel,
     datasets,
     default_model,
     ...(enable_observability && {
@@ -297,6 +303,8 @@

   const yamlString = yaml.dump(pipelineConfig);

+  console.log(yamlString);
+
   return {
     yamlString,
     inputPath,
4 changes: 3 additions & 1 deletion website/src/app/api/writePipelineConfig/route.ts
@@ -22,6 +22,7 @@ export async function POST(request: Request) {
     system_prompt,
     namespace,
     apiKeys,
+    optimizerModel,
   } = await request.json();

   if (!name) {
@@ -55,7 +56,8 @@
     system_prompt,
     apiKeys,
     docetl_encryption_key,
-    true
+    true,
+    optimizerModel
   );

   // Use the FastAPI endpoint to write the pipeline config
2 changes: 1 addition & 1 deletion website/src/components/OperationCard.tsx
@@ -906,6 +906,7 @@ export const OperationCard: React.FC<Props> = ({ index, id }) => {
           system_prompt: systemPrompt,
           namespace: namespace,
           apiKeys: apiKeys,
+          optimizerModel: optimizerModel,
         }),
       });

@@ -922,7 +923,6 @@
       sendMessage({
         yaml_config: filePath,
         optimize: true,
-        optimizer_model: optimizerModel,
       });
     } catch (error) {
       console.error("Error optimizing operation:", error);