Skip to content

Commit

Permalink
refactor: DSLRunner is now a pull-based execution model
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar committed Jan 9, 2025
1 parent fb26cdf commit 186c827
Show file tree
Hide file tree
Showing 36 changed files with 1,244 additions and 813 deletions.
19 changes: 10 additions & 9 deletions docetl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,32 +49,32 @@
result = optimized_pipeline.run()
"""

import os
import inspect
from typing import Any, Dict, List, Optional, Callable, Union
import os
from typing import Any, Callable, Dict, List, Optional, Union

import yaml
from rich import print

from docetl.builder import Optimizer
from docetl.runner import DSLRunner
from docetl.schemas import (
ClusterOp,
Dataset,
EquijoinOp,
FilterOp,
GatherOp,
MapOp,
ReduceOp,
ResolveOp,
SplitOp,
UnnestOp,
ClusterOp,
SampleOp,
OpType,
ParallelMapOp,
ParsingTool,
PipelineOutput,
PipelineStep,
ReduceOp,
ResolveOp,
SampleOp,
SplitOp,
UnnestOp,
)


Expand Down Expand Up @@ -166,9 +166,10 @@ def __init__(
self._load_env()

def _load_env(self):
from dotenv import load_dotenv
import os

from dotenv import load_dotenv

# Get the current working directory
cwd = os.getcwd()

Expand Down
4 changes: 3 additions & 1 deletion docetl/base_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# from ..operations import map
# MapOp = map.MapOperation.schema


class ToolFunction(BaseModel):
name: str
description: str
Expand Down Expand Up @@ -96,7 +97,8 @@ class PipelineStep(BaseModel):
name: str
operations: List[Union[Dict[str, Any], str]]
input: Optional[str] = None



class PipelineOutput(BaseModel):
"""
Represents the output configuration for a pipeline.
Expand Down
189 changes: 124 additions & 65 deletions docetl/builder.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docetl/cli.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
from pathlib import Path
from typing import Optional

import os
import typer
from dotenv import load_dotenv

from docetl.operations.utils import clear_cache as cc
from docetl.runner import DSLRunner

from dotenv import load_dotenv

app = typer.Typer()


@app.command()
def build(
yaml_file: Path = typer.Argument(
Expand Down
15 changes: 8 additions & 7 deletions docetl/config_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import datetime
import math
import os
from docetl.console import get_console
from docetl.utils import decrypt, load_config
from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union
from docetl.operations.utils import APIWrapper

import pyrate_limiter
from inspect import isawaitable
import math
from rich.console import Console

from docetl.console import get_console
from docetl.operations.utils import APIWrapper
from docetl.utils import decrypt, load_config


class BucketCollection(pyrate_limiter.BucketFactory):
def __init__(self, **buckets):
Expand Down Expand Up @@ -108,7 +110,6 @@ def __init__(
self.rate_limiter = pyrate_limiter.Limiter(bucket_factory, max_delay=math.inf)

self.api = APIWrapper(self)

def reset_env(self):
os.environ = self._original_env

23 changes: 16 additions & 7 deletions docetl/console.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import queue
import threading
import time
from io import StringIO
from typing import Any, Optional, Tuple

from rich.console import Console
from io import StringIO
import threading
import queue

from docetl.utils import StageType, get_stage_description


class ThreadSafeConsole(Console):
def __init__(self, *args, **kwargs):
self.buffer = StringIO()
Expand Down Expand Up @@ -37,8 +40,10 @@ def status(
refresh_per_second=refresh_per_second,
)
return status_renderable

def post_optimizer_rationale(self, should_optimize: bool, rationale: str, validator_prompt: str):

def post_optimizer_rationale(
self, should_optimize: bool, rationale: str, validator_prompt: str
):
self.optimizer_rationale = (should_optimize, rationale, validator_prompt)

def post_optimizer_status(self, stage: StageType):
Expand All @@ -47,8 +52,11 @@ def post_optimizer_status(self, stage: StageType):
def get_optimizer_progress(self) -> Tuple[str, float]:
if len(self.optimizer_statuses) == 0:
return ("Optimization starting...", 0)

if len(self.optimizer_statuses) > 0 and self.optimizer_statuses[-1][0] == StageType.END:

if (
len(self.optimizer_statuses) > 0
and self.optimizer_statuses[-1][0] == StageType.END
):
return (get_stage_description(StageType.END), 1)

num_stages = len(StageType) - 1
Expand Down Expand Up @@ -89,6 +97,7 @@ def get_console():
highlight=False,
)
else:

class NoOpConsole(Console):
def post_optimizer_status(self, *args, **kwargs):
pass
Expand Down
7 changes: 4 additions & 3 deletions docetl/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List, Optional, Union, Literal
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from pydantic import BaseModel

from docetl.parsing_tools import get_parser, get_parsing_tools
from docetl.base_schemas import ParsingTool
from docetl.parsing_tools import get_parser, get_parsing_tools


def create_parsing_tool_map(
Expand Down Expand Up @@ -73,7 +74,7 @@ class schema(BaseModel):
path: str
source: str = "local"
parsing: Optional[List[Dict[str, str]]] = None

def __init__(
self,
runner,
Expand Down
9 changes: 5 additions & 4 deletions docetl/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from abc import ABC, ABCMeta, abstractmethod
from typing import Dict, List, Optional, Tuple

from docetl.operations.utils import APIWrapper
from docetl.console import DOCETL_CONSOLE
from rich.console import Console
from rich.status import Status
import jsonschema
from pydantic import BaseModel
from rich.console import Console
from rich.status import Status

from docetl.console import DOCETL_CONSOLE
from docetl.operations.utils import APIWrapper


# FIXME: This should probably live in some utils module?
Expand Down
64 changes: 42 additions & 22 deletions docetl/operations/cluster.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import numpy as np
from jinja2 import Environment, Template
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from jinja2 import Environment, Template

from .base import BaseOperation
from .utils import RichLoopBar, strict_render
from .clustering_utils import get_embeddings_for_clustering

from .utils import RichLoopBar, strict_render


class ClusterOperation(BaseOperation):
Expand Down Expand Up @@ -101,9 +102,9 @@ def execute(
)

tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings)

if "collapse" in self.config:
tree = self.collapse_tree(tree, collapse = self.config["collapse"])
tree = self.collapse_tree(tree, collapse=self.config["collapse"])

self.prompt_template = Template(self.config["summary_prompt"])
cost += self.annotate_clustering_tree(tree)
Expand All @@ -127,7 +128,7 @@ def build_tree(i):
# res["embedding"] = list(embeddings[i])
return res
return {
"children": [
"children": [
build_tree(cl.children_[i - nsamples, 0]),
build_tree(cl.children_[i - nsamples, 1]),
],
Expand All @@ -139,37 +140,54 @@ def build_tree(i):
def get_tree_distances(self, t):
res = set()
if "distance" in t:
res.update(set([t["distance"] - child["distance"] for child in t["children"] if "distance" in child]))
res.update(
set(
[
t["distance"] - child["distance"]
for child in t["children"]
if "distance" in child
]
)
)
if "children" in t:
for child in t["children"]:
res.update(self.get_tree_distances(child))
return res
def _collapse_tree(self, t, parent_dist = None, collapse = None):

def _collapse_tree(self, t, parent_dist=None, collapse=None):
if "children" in t:
if ( "distance" in t
if (
"distance" in t
and parent_dist is not None
and collapse is not None
and parent_dist - t["distance"] < collapse):
return [grandchild
for child in t["children"]
for grandchild in self._collapse_tree(child, parent_dist=parent_dist, collapse=collapse)]
and parent_dist - t["distance"] < collapse
):
return [
grandchild
for child in t["children"]
for grandchild in self._collapse_tree(
child, parent_dist=parent_dist, collapse=collapse
)
]
else:
res = dict(t)
res["children"] = [grandchild
for idx, child in enumerate(t["children"])
for grandchild in self._collapse_tree(child, parent_dist=t["distance"], collapse=collapse)]
res["children"] = [
grandchild
for idx, child in enumerate(t["children"])
for grandchild in self._collapse_tree(
child, parent_dist=t["distance"], collapse=collapse
)
]
return [res]
else:
return [t]
def collapse_tree(self, tree, collapse = None):

def collapse_tree(self, tree, collapse=None):
if collapse is not None:
tree_distances = np.array(sorted(self.get_tree_distances(tree)))
collapse = tree_distances[int(len(tree_distances) * collapse)]
return self._collapse_tree(tree, collapse=collapse)[0]


def annotate_clustering_tree(self, t):
if "children" in t:
with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
Expand Down Expand Up @@ -218,7 +236,9 @@ def validation_fn(response: Dict[str, Any]):
else None
),
verbose=self.config.get("verbose", False),
litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
litellm_completion_kwargs=self.config.get(
"litellm_completion_kwargs", {}
),
)
total_cost += response.total_cost
if response.validated:
Expand Down
Loading

0 comments on commit 186c827

Please sign in to comment.