Merge pull request #134 from treasure-data/parallel-upload
Add parallel upload for BulkImportWriter
chezou authored Aug 29, 2024
2 parents 55f7aef + 04c5500 commit b560475
Showing 2 changed files with 98 additions and 55 deletions.
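
For context, here is a minimal usage sketch of the options introduced by this PR; it is not part of the commit. It assumes a configured pytd.Client and that Client.load_table_from_dataframe forwards extra keyword arguments to BulkImportWriter.write_dataframe. The fmt, max_workers, and chunk_record_size parameters come from this diff; the API key, endpoint, and table names are placeholders.

    import pandas as pd
    import pytd

    # Placeholder credentials and names; replace with real values.
    client = pytd.Client(
        apikey="1/XXXXXXXX",
        endpoint="https://api.treasuredata.com/",
        database="sample_db",
    )

    df = pd.DataFrame({"a": range(100_000), "b": range(100_000)})

    # The chunked, parallel upload path applies only when fmt="msgpack".
    client.load_table_from_dataframe(
        df,
        "bulk_import_test",
        writer="bulk_import",
        if_exists="overwrite",
        fmt="msgpack",
        max_workers=8,             # threads running bulk_import.upload_part
        chunk_record_size=20_000,  # records per temporary msgpack file
    )
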
76 changes: 34 additions & 42 deletions pytd/tests/test_writer.py
@@ -1,7 +1,7 @@
import io
import os
import tempfile
import unittest
from unittest.mock import ANY, MagicMock
from unittest.mock import ANY, MagicMock, patch

import numpy as np
import pandas as pd
@@ -89,9 +89,6 @@ def test_cast_dtypes(self):
# This is for consistency of _get_schema
self.assertTrue(pd.isna(dft["O"][2]))

@unittest.skipIf(
pd.__version__ < "1.0.0", "pd.NA is not supported in this pandas version"
)
def test_cast_dtypes_nullable(self):
dft = pd.DataFrame(
{
@@ -264,7 +261,7 @@ def test_write_dataframe_tempfile_deletion(self):
# file pointer to a temp CSV file
fp = self.writer._bulk_import.call_args[0][1]
# temp file should not exist
self.assertFalse(os.path.isfile(fp.name))
self.assertFalse(os.path.isfile(fp[0].name))

# Case #2: bulk import failed
self.writer._bulk_import = MagicMock(side_effect=Exception())
@@ -273,22 +270,22 @@ def test_write_dataframe_tempfile_deletion(self):
pd.DataFrame([[1, 2], [3, 4]]), self.table, "overwrite"
)
fp = self.writer._bulk_import.call_args[0][1]
self.assertFalse(os.path.isfile(fp.name))
self.assertFalse(os.path.isfile(fp[0].name))

def test_write_dataframe_msgpack(self):
df = pd.DataFrame([[1, 2], [3, 4]])
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
api_client = self.table.client.api_client
self.assertTrue(api_client.create_bulk_import.called)
self.assertTrue(api_client.create_bulk_import().upload_part.called)
_bytes = BulkImportWriter()._write_msgpack_stream(
df.to_dict(orient="records"), io.BytesIO()
)
size = _bytes.getbuffer().nbytes
api_client.create_bulk_import().upload_part.assert_called_with(
"part", ANY, size
)
self.assertFalse(api_client.create_bulk_import().upload_file.called)
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp = BulkImportWriter()._write_msgpack_stream(
df.to_dict(orient="records"), fp
)
api_client.create_bulk_import().upload_part.assert_called_with(
"part-0", ANY, 62
)
self.assertFalse(api_client.create_bulk_import().upload_file.called)

def test_write_dataframe_msgpack_with_int_na(self):
# Although this conversion ensures pd.NA Int64 dtype to None,
@@ -305,15 +302,15 @@ def test_write_dataframe_msgpack_with_int_na(self):
{"a": 3, "b": 4, "c": 5, "time": 1234},
]
self.writer._write_msgpack_stream = MagicMock()
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0], expected_list
)
with patch("pytd.writer.os.unlink"):
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
print(self.writer._write_msgpack_stream.call_args[0][0][0:2])
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)

@unittest.skipIf(
pd.__version__ < "1.0.0", "pd.NA not supported in this pandas version"
)
def test_write_dataframe_msgpack_with_string_na(self):
df = pd.DataFrame(
data=[{"a": "foo", "b": "bar"}, {"a": "buzz", "b": "buzz", "c": "alice"}],
@@ -325,15 +322,14 @@ def test_write_dataframe_msgpack_with_string_na(self):
{"a": "buzz", "b": "buzz", "c": "alice", "time": 1234},
]
self.writer._write_msgpack_stream = MagicMock()
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0], expected_list
)
with patch("pytd.writer.os.unlink"):
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)

@unittest.skipIf(
pd.__version__ < "1.0.0", "pd.NA not supported in this pandas version"
)
def test_write_dataframe_msgpack_with_boolean_na(self):
df = pd.DataFrame(
data=[{"a": True, "b": False}, {"a": False, "b": True, "c": True}],
@@ -345,11 +341,13 @@ def test_write_dataframe_msgpack_with_boolean_na(self):
{"a": "false", "b": "true", "c": "true", "time": 1234},
]
self.writer._write_msgpack_stream = MagicMock()
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0], expected_list
)
with patch("pytd.writer.os.unlink"):
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)

def test_write_dataframe_invalid_if_exists(self):
with self.assertRaises(ValueError):
@@ -408,9 +406,6 @@ def test_write_dataframe_with_int_na(self):
self.writer.td_spark.spark.createDataFrame.call_args[0][0], expected_df
)

@unittest.skipIf(
pd.__version__ < "1.0.0", "pd.NA is not supported in this pandas version"
)
def test_write_dataframe_with_string_na(self):
df = pd.DataFrame(
data=[{"a": "foo", "b": "bar"}, {"a": "buzz", "b": "buzz", "c": "alice"}],
@@ -423,9 +418,6 @@ def test_write_dataframe_with_string_na(self):
self.writer.td_spark.spark.createDataFrame.call_args[0][0], expected_df
)

@unittest.skipIf(
pd.__version__ < "1.0.0", "pd.NA is not supported in this pandas version"
)
def test_write_dataframe_with_boolean_na(self):
df = pd.DataFrame(
data=[{"a": True, "b": False}, {"a": False, "b": True, "c": True}],
77 changes: 64 additions & 13 deletions pytd/writer.py
@@ -1,11 +1,11 @@
import abc
import gzip
import io
import logging
import os
import tempfile
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack

import msgpack
@@ -312,7 +312,16 @@ class BulkImportWriter(Writer):
td-client-python's bulk importer.
"""

def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=False):
def write_dataframe(
self,
dataframe,
table,
if_exists,
fmt="csv",
keep_list=False,
max_workers=5,
chunk_record_size=10_000,
):
"""Write a given DataFrame to a Treasure Data table.
This method internally converts a given :class:`pandas.DataFrame` into a
@@ -407,6 +416,14 @@ def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=Fals
Or, you can use :func:`Client.load_table_from_dataframe` function as well.
>>> client.load_table_from_dataframe(df, "bulk_import", keep_list=True)
max_workers : int, optional, default: 5
The maximum number of threads that can be used to execute the given calls.
This is used only when ``fmt`` is ``msgpack``.
chunk_record_size : int, optional, default: 10_000
The number of records to be written in a single file. This is used only when
``fmt`` is ``msgpack``.
"""
if self.closed:
raise RuntimeError("this writer is already closed and no longer available")
@@ -424,26 +441,42 @@ def write_dataframe(self, dataframe, table, if_exists, fmt="csv", keep_list=Fals
_cast_dtypes(dataframe, keep_list=keep_list)

with ExitStack() as stack:
fps = []
if fmt == "csv":
fp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
stack.callback(os.unlink, fp.name)
stack.callback(fp.close)
dataframe.to_csv(fp.name)
fps.append(fp)
elif fmt == "msgpack":
_replace_pd_na(dataframe)

fp = io.BytesIO()
fp = self._write_msgpack_stream(dataframe.to_dict(orient="records"), fp)
stack.callback(fp.close)
try:
for start in range(0, len(dataframe), chunk_record_size):
records = dataframe.iloc[
start : start + chunk_record_size
].to_dict(orient="records")
fp = tempfile.NamedTemporaryFile(
suffix=".msgpack.gz", delete=False
)
fp = self._write_msgpack_stream(records, fp)
fps.append(fp)
stack.callback(os.unlink, fp.name)
stack.callback(fp.close)
except OSError as e:
raise RuntimeError(
"failed to create a temporary file. "
"Larger chunk_record_size may mitigate the issue."
) from e
else:
raise ValueError(
f"unsupported format '{fmt}' for bulk import. "
"should be 'csv' or 'msgpack'"
)
self._bulk_import(table, fp, if_exists, fmt)
self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
stack.close()

def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
"""Write a specified CSV file to a Treasure Data table.
This method uploads the file to Treasure Data via bulk import API.
@@ -453,7 +486,7 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
table : :class:`pytd.table.Table`
Target table.
file_like : File like object
file_likes : List of file like objects
Data in this file will be loaded to a target table.
if_exists : str, {'error', 'overwrite', 'append', 'ignore'}
@@ -466,6 +499,10 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
fmt : str, optional, {'csv', 'msgpack'}, default: 'csv'
File format for bulk import. See also :func:`write_dataframe`
max_workers : int, optional, default: 5
The maximum number of threads that can be used to execute the given calls.
This is used only when ``fmt`` is ``msgpack``.
"""
params = None
if table.exists:
@@ -490,19 +527,31 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv"):
bulk_import = table.client.api_client.create_bulk_import(
session_name, table.database, table.table, params=params
)
s_time = time.time()
try:
logger.info(f"uploading data converted into a {fmt} file")
if fmt == "msgpack":
size = file_like.getbuffer().nbytes
# To skip API._prepare_file(), which recreate msgpack again.
bulk_import.upload_part("part", file_like, size)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for i, fp in enumerate(file_likes):
fsize = fp.tell()
fp.seek(0)
executor.submit(
bulk_import.upload_part,
f"part-{i}",
fp,
fsize,
)
logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
else:
bulk_import.upload_file("part", fmt, file_like)
fp = file_likes[0]
bulk_import.upload_file("part", fmt, fp)
bulk_import.freeze()
except Exception as e:
bulk_import.delete()
raise RuntimeError(f"failed to upload file: {e}")

logger.debug(f"uploaded data in {time.time() - s_time:.2f} sec")

logger.info("performing a bulk import job")
job = bulk_import.perform(wait=True)

@@ -546,7 +595,9 @@ def _write_msgpack_stream(self, items, stream):
mp = packer.pack(normalized_msgpack(item))
gz.write(mp)

stream.seek(0)
logger.debug(
f"created a msgpack file: {stream.name}. File size: {stream.tell()}"
)
return stream


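
To summarize the new upload path, the following self-contained sketch reproduces the pattern the diff implements: split the records into chunks, write each chunk to a gzipped msgpack temporary file, and upload the parts on a thread pool. The upload() function is a stand-in for td-client's bulk_import.upload_part and is assumed for illustration; pytd additionally normalizes each record via normalized_msgpack, which is omitted here.

    import gzip
    import os
    import tempfile
    from concurrent.futures import ThreadPoolExecutor

    import msgpack


    def upload(part_name, fp, size):
        # Stand-in for bulk_import.upload_part(part_name, fp, size).
        print(f"uploading {part_name} ({size} bytes)")


    def write_chunk(records):
        # Write one chunk of records to a gzipped msgpack temp file,
        # analogous to BulkImportWriter._write_msgpack_stream.
        fp = tempfile.NamedTemporaryFile(suffix=".msgpack.gz", delete=False)
        with gzip.GzipFile(mode="wb", fileobj=fp) as gz:
            packer = msgpack.Packer()
            for record in records:
                gz.write(packer.pack(record))
        fp.seek(0)
        return fp


    def parallel_upload(records, chunk_record_size=10_000, max_workers=5):
        fps = [
            write_chunk(records[start : start + chunk_record_size])
            for start in range(0, len(records), chunk_record_size)
        ]
        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for i, fp in enumerate(fps):
                    fp.seek(0, os.SEEK_END)
                    size = fp.tell()
                    fp.seek(0)
                    # Each part gets a unique name, mirroring f"part-{i}" in the diff.
                    executor.submit(upload, f"part-{i}", fp, size)
        finally:
            for fp in fps:
                fp.close()
                os.unlink(fp.name)

Larger chunk_record_size values mean fewer, bigger parts and fewer temporary files; smaller values give the thread pool more parts to upload concurrently.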
