Skip to content

Commit

Permalink
chore: added concurrent threads count parameter
Browse files (browse the repository at this point in the history)
  • Loading branch information
lapsule committed Nov 19, 2024
1 parent 5dcfeae commit 241af17
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions docetl/operations/code_operations.py
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Any, Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor
from docetl.operations.base import BaseOperation
Expand All @@ -7,6 +8,7 @@ class CodeMapOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_map"
code: str
concurrent_thread_count: int = os.cpu_count()
drop_keys: Optional[List[str]] = None

def syntax_check(self) -> None:
Expand All @@ -27,7 +29,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
transform_fn = namespace["transform"]

results = []
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor:
futures = [executor.submit(transform_fn, doc) for doc in input_data]
pbar = RichLoopBar(
range(len(futures)),
Expand All @@ -51,6 +53,7 @@ class CodeReduceOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_reduce"
code: str
concurrent_thread_count: int = os.cpu_count()

def syntax_check(self) -> None:
config = self.schema(**self.config)
Expand Down Expand Up @@ -89,7 +92,7 @@ def get_group_key(item):
grouped_data = list(grouped_data.items())

results = []
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor:
futures = [executor.submit(reduce_fn, group) for _, group in grouped_data]
pbar = RichLoopBar(
range(len(futures)),
Expand All @@ -113,6 +116,7 @@ class CodeFilterOperation(BaseOperation):
class schema(BaseOperation.schema):
type: str = "code_filter"
code: str
concurrent_thread_count: int = os.cpu_count()

def syntax_check(self) -> None:
config = self.schema(**self.config)
Expand All @@ -132,7 +136,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
filter_fn = namespace["transform"]

results = []
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor:
futures = [executor.submit(filter_fn, doc) for doc in input_data]
pbar = RichLoopBar(
range(len(futures)),
Expand Down

0 comments on commit 241af17

Please sign in to comment.