replacing multiprocessing pool with pipe (#491)
* replacing multiprocessing pool with pipe

* code styling fix

* dropping obsolete chunk_size config parameter
Szasza authored Mar 24, 2024
1 parent 1e565d9 commit a915385
Showing 2 changed files with 47 additions and 40 deletions.
2 changes: 0 additions & 2 deletions docs/source/usage.md
@@ -126,8 +126,6 @@ The full set of configuration options are:
- `log_file` - str: Write log messages to a file at this path
- `n_procs` - int: Number of processes to run in parallel when
parsing in CLI mode (Default: `1`)
- `chunk_size` - int: Number of files to give to each process
when running in parallel.

:::{note}
Setting this to a number larger than one can improve
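With `chunk_size` dropped, parallelism is now controlled by `n_procs` alone. A minimal sketch of how that option is read, assuming the usual `[general]` section; the file contents below are illustrative, not taken from the commit:

```python
from configparser import ConfigParser

# Illustrative configuration contents; a real parsedmarc config file lives on
# disk and usually carries many more options and sections.
sample_config = """
[general]
log_file = parsedmarc.log
n_procs = 4
"""

config = ConfigParser()
config.read_string(sample_config)
general_config = config["general"]

n_procs = 1  # default used when the option is absent
if "n_procs" in general_config:
    n_procs = general_config.getint("n_procs")

print(n_procs)  # -> 4
```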
85 changes: 47 additions & 38 deletions parsedmarc/cli.py
@@ -8,13 +8,12 @@
from configparser import ConfigParser
from glob import glob
import logging
import math
from collections import OrderedDict
import json
from ssl import CERT_NONE, create_default_context
from multiprocessing import Pool, Value
from itertools import repeat
from multiprocessing import Pipe, Process
import sys
import time
from tqdm import tqdm

from parsedmarc import get_dmarc_reports_from_mailbox, watch_inbox, \
@@ -42,7 +41,7 @@ def _str_to_list(s):


def cli_parse(file_path, sa, nameservers, dns_timeout,
ip_db_path, offline, parallel=False):
ip_db_path, offline, conn, parallel=False):
"""Separated this function for multiprocessing"""
try:
file_results = parse_report_file(file_path,
@@ -52,18 +51,11 @@ def cli_parse(file_path, sa, nameservers, dns_timeout,
dns_timeout=dns_timeout,
strip_attachment_payloads=sa,
parallel=parallel)
conn.send([file_results, file_path])
except ParserError as error:
return error, file_path
conn.send([error, file_path])
finally:
global counter
with counter.get_lock():
counter.value += 1
return file_results, file_path


def init(ctr):
global counter
counter = ctr
conn.close()


def _main():
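The reworked `cli_parse` follows a common pattern: the child process sends its result (or the exception it caught) back through one end of a `multiprocessing.Pipe` and closes that end in a `finally` block. A self-contained sketch of the same idea, with a trivial stand-in for `parse_report_file`:

```python
from multiprocessing import Pipe, Process


def parse_one(path):
    # Stand-in for parse_report_file(); raises on an unexpected path.
    if not path.endswith(".xml"):
        raise ValueError(f"unsupported file: {path}")
    return {"report": path}


def worker(path, conn):
    """Parse a single file and send [result, path] back through the pipe."""
    try:
        conn.send([parse_one(path), path])
    except Exception as error:
        # Send the exception object itself so the parent can detect failures.
        conn.send([error, path])
    finally:
        conn.close()


if __name__ == "__main__":
    parent_conn, child_conn = Pipe()
    proc = Process(target=worker, args=("report1.xml", child_conn))
    proc.start()
    result, path = parent_conn.recv()  # blocks until the child sends
    proc.join()
    print(path, result)
```

Sending the exception object rather than raising lets the parent tell failures apart from results by type, which is how the driver loop below treats `ParserError`.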
@@ -481,7 +473,6 @@ def process_reports(reports_):
gmail_api_oauth2_port=8080,
log_file=args.log_file,
n_procs=1,
chunk_size=1,
ip_db_path=None,
la_client_id=None,
la_client_secret=None,
@@ -551,8 +542,6 @@ def process_reports(reports_):
opts.log_file = general_config["log_file"]
if "n_procs" in general_config:
opts.n_procs = general_config.getint("n_procs")
if "chunk_size" in general_config:
opts.chunk_size = general_config.getint("chunk_size")
if "ip_db_path" in general_config:
opts.ip_db_path = general_config["ip_db_path"]
else:
@@ -1144,29 +1133,49 @@ def process_reports(reports_):
for mbox_path in mbox_paths:
file_paths.remove(mbox_path)

counter = Value('i', 0)
pool = Pool(opts.n_procs, initializer=init, initargs=(counter,))
results = pool.starmap_async(cli_parse,
zip(file_paths,
repeat(opts.strip_attachment_payloads),
repeat(opts.nameservers),
repeat(opts.dns_timeout),
repeat(opts.ip_db_path),
repeat(opts.offline),
repeat(opts.n_procs >= 1)),
opts.chunk_size)
counter = 0

results = []

if sys.stdout.isatty():
pbar = tqdm(total=len(file_paths))
while not results.ready():
pbar.update(counter.value - pbar.n)
time.sleep(0.1)
pbar.close()
else:
while not results.ready():
time.sleep(0.1)
results = results.get()
pool.close()
pool.join()

for batch_index in range(math.ceil(len(file_paths) / opts.n_procs)):
processes = []
connections = []

for proc_index in range(
opts.n_procs * batch_index,
opts.n_procs * (batch_index + 1)):
if proc_index >= len(file_paths):
break

parent_conn, child_conn = Pipe()
connections.append(parent_conn)

process = Process(target=cli_parse, args=(
file_paths[proc_index],
opts.strip_attachment_payloads,
opts.nameservers,
opts.dns_timeout,
opts.ip_db_path,
opts.offline,
child_conn,
opts.n_procs >= 1,
))
processes.append(process)

for proc in processes:
proc.start()

for proc in processes:
proc.join()
if sys.stdout.isatty():
counter += 1
pbar.update(counter - pbar.n)

for conn in connections:
results.append(conn.recv())

for result in results:
if type(result[0]) is ParserError:
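Putting the pieces together, the new driver loop spawns at most `n_procs` worker processes per batch, each with its own pipe, starts and joins them, and then collects whatever each child sent. A standalone sketch of that batching pattern; the `worker` below is a trivial stand-in for `cli_parse`, and tqdm is already imported by `cli.py`:

```python
import math
import sys
from multiprocessing import Pipe, Process

from tqdm import tqdm


def worker(path, conn):
    # Stand-in for cli_parse(): send back [result, path] and close the pipe end.
    conn.send([{"parsed": path}, path])
    conn.close()


if __name__ == "__main__":
    file_paths = [f"report{i}.xml" for i in range(10)]
    n_procs = 4

    results = []
    pbar = tqdm(total=len(file_paths)) if sys.stdout.isatty() else None

    # Work through the files in batches of at most n_procs parallel workers.
    for batch_index in range(math.ceil(len(file_paths) / n_procs)):
        batch = file_paths[batch_index * n_procs:(batch_index + 1) * n_procs]

        processes = []
        connections = []
        for path in batch:
            parent_conn, child_conn = Pipe()
            connections.append(parent_conn)
            processes.append(Process(target=worker, args=(path, child_conn)))

        for proc in processes:
            proc.start()

        # Drain each pipe, then join; draining first keeps a child from
        # blocking on send() if a payload ever exceeds the pipe buffer.
        for conn in connections:
            results.append(conn.recv())

        for proc in processes:
            proc.join()
            if pbar is not None:
                pbar.update(1)

    if pbar is not None:
        pbar.close()

    print(len(results), "files parsed")
```

The sketch receives before joining; the commit joins each batch first and then receives, which behaves the same for typical report payloads.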
