From 241af171a5a769d14637ed962a507ed9d00f2b44 Mon Sep 17 00:00:00 2001 From: Tim Date: Tue, 19 Nov 2024 10:57:56 +0800 Subject: [PATCH] chore: added concurrent threads count parameter --- docetl/operations/code_operations.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docetl/operations/code_operations.py b/docetl/operations/code_operations.py index 09a62c9a..808b22fa 100644 --- a/docetl/operations/code_operations.py +++ b/docetl/operations/code_operations.py @@ -1,3 +1,4 @@ +import os from typing import Any, Dict, List, Optional, Tuple from concurrent.futures import ThreadPoolExecutor from docetl.operations.base import BaseOperation @@ -7,6 +8,7 @@ class CodeMapOperation(BaseOperation): class schema(BaseOperation.schema): type: str = "code_map" code: str + concurrent_thread_count: int = os.cpu_count() drop_keys: Optional[List[str]] = None def syntax_check(self) -> None: @@ -27,7 +29,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: transform_fn = namespace["transform"] results = [] - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor: futures = [executor.submit(transform_fn, doc) for doc in input_data] pbar = RichLoopBar( range(len(futures)), @@ -51,6 +53,7 @@ class CodeReduceOperation(BaseOperation): class schema(BaseOperation.schema): type: str = "code_reduce" code: str + concurrent_thread_count: int = os.cpu_count() def syntax_check(self) -> None: config = self.schema(**self.config) @@ -89,7 +92,7 @@ def get_group_key(item): grouped_data = list(grouped_data.items()) results = [] - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor: futures = [executor.submit(reduce_fn, group) for _, group in grouped_data] pbar = RichLoopBar( range(len(futures)), @@ -113,6 +116,7 @@ class CodeFilterOperation(BaseOperation): class schema(BaseOperation.schema): type: str = "code_filter" code: str + concurrent_thread_count: int = os.cpu_count() def syntax_check(self) -> None: config = self.schema(**self.config) @@ -132,7 +136,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: filter_fn = namespace["transform"] results = [] - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor: futures = [executor.submit(filter_fn, doc) for doc in input_data] pbar = RichLoopBar( range(len(futures)),