Skip to content

Commit

Permalink
Simplify the buffer builder for RPC client (#1688)
Browse files Browse the repository at this point in the history
Fixes #1664

Signed-off-by: Tao He <[email protected]>
  • Loading branch information
sighingnow authored Dec 21, 2023
1 parent c32dcbb commit 52ad1a9
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 87 deletions.
46 changes: 31 additions & 15 deletions python/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,9 @@ void bind_blobs(py::module& mod) {
[](BlobWriter* self, size_t const offset, uintptr_t ptr,
size_t const size,
size_t const concurrency = memory::default_memcpy_concurrency) {
if (size == 0) {
return;
}
memory::concurrent_memcpy(self->data() + offset,
reinterpret_cast<void*>(ptr), size,
concurrency);
Expand All @@ -553,6 +556,9 @@ void bind_blobs(py::module& mod) {
"copy",
[](BlobWriter* self, size_t offset, py::buffer const& buffer,
size_t const concurrency = memory::default_memcpy_concurrency) {
if (self->size() == 0) {
return;
}
throw_on_error(copy_memoryview(buffer.ptr(), self->data(),
self->size(), offset, concurrency));
},
Expand All @@ -564,18 +570,20 @@ void bind_blobs(py::module& mod) {
[](BlobWriter* self, size_t offset, py::bytes const& bs,
size_t const concurrency = memory::default_memcpy_concurrency) {
char* buffer = nullptr;
ssize_t length = 0;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bs.ptr(), &buffer, &length)) {
ssize_t size = 0;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bs.ptr(), &buffer, &size)) {
py::pybind11_fail("Unable to extract bytes contents!");
}
if (offset + length > self->size()) {
if (size == 0) {
return;
}
if (offset + size > self->size()) {
throw_on_error(Status::AssertionFailed(
"Expect a source buffer with size at most '" +
std::to_string(self->size() - offset) +
"', but the buffer size is '" + std::to_string(length) +
"'"));
"', but the buffer size is '" + std::to_string(size) + "'"));
}
memory::concurrent_memcpy(self->data() + offset, buffer, length,
memory::concurrent_memcpy(self->data() + offset, buffer, size,
concurrency);
},
"offset"_a, "bytes"_a,
Expand Down Expand Up @@ -670,12 +678,12 @@ void bind_blobs(py::module& mod) {
"wrap",
[](py::bytes const& bs) -> std::shared_ptr<RemoteBlobWriter> {
char* buffer = nullptr;
ssize_t length = 0;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bs.ptr(), &buffer, &length)) {
ssize_t size = 0;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bs.ptr(), &buffer, &size)) {
py::pybind11_fail("Unable to extract bytes contents!");
}
return RemoteBlobWriter::Wrap(
reinterpret_cast<const uint8_t*>(buffer), length);
reinterpret_cast<const uint8_t*>(buffer), size);
},
"data"_a)
.def_property_readonly("size", &RemoteBlobWriter::size,
Expand Down Expand Up @@ -712,6 +720,9 @@ void bind_blobs(py::module& mod) {
[](RemoteBlobWriter* self, size_t const offset, uintptr_t ptr,
size_t const size,
size_t const concurrency = memory::default_memcpy_concurrency) {
if (size == 0) {
return;
}
memory::concurrent_memcpy(self->data() + offset,
reinterpret_cast<void*>(ptr), size,
concurrency);
Expand All @@ -723,6 +734,9 @@ void bind_blobs(py::module& mod) {
"copy",
[](RemoteBlobWriter* self, size_t offset, py::buffer const& buffer,
size_t const concurrency = memory::default_memcpy_concurrency) {
if (self->size() == 0) {
return;
}
throw_on_error(copy_memoryview(buffer.ptr(), self->data(),
self->size(), offset, concurrency));
},
Expand All @@ -734,18 +748,20 @@ void bind_blobs(py::module& mod) {
[](RemoteBlobWriter* self, size_t offset, py::bytes const& bs,
size_t const concurrency = memory::default_memcpy_concurrency) {
char* buffer = nullptr;
ssize_t length = 0;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bs.ptr(), &buffer, &length)) {
ssize_t size = 0;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bs.ptr(), &buffer, &size)) {
py::pybind11_fail("Unable to extract bytes contents!");
}
if (offset + length > self->size()) {
if (size == 0) {
return;
}
if (offset + size > self->size()) {
throw_on_error(Status::AssertionFailed(
"Expect a source buffer with size at most '" +
std::to_string(self->size() - offset) +
"', but the buffer size is '" + std::to_string(length) +
"'"));
"', but the buffer size is '" + std::to_string(size) + "'"));
}
memory::concurrent_memcpy(self->data() + offset, buffer, length,
memory::concurrent_memcpy(self->data() + offset, buffer, size,
concurrency);
},
"offset"_a, "bytes"_a,
Expand Down
49 changes: 18 additions & 31 deletions python/vineyard/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# limitations under the License.
#

import ctypes
import json
import pickle
from typing import Union
Expand Down Expand Up @@ -150,37 +149,25 @@ def build_buffer(
If address is None or size is 0, an empty blob will be returned.
'''
if client.is_rpc:
# copy the address with size to a local payloads
if size == 0 or address is None:
meta = ObjectMeta()
meta.nbytes = 0
meta.typename = "vineyard::RemoteBlob"
return client.create_metadata(meta)
if isinstance(address, bytes):
payload = address
else:
payload = bytearray(size)
address_bytes = (ctypes.c_byte * size).from_address(address)
payload[:size] = memoryview(address_bytes)[:size]
buffer = RemoteBlobBuilder(size)
buffer.copy(0, payload)
id = client.create_remote_blob(buffer)
meta = client.get_meta(id)
return meta

if size == 0:
return client.create_empty_blob()
if address is None:
return client.create_empty_blob()
existing = client.find_shared_memory(address)
if existing is not None:
return client.get_meta(existing)
buffer = client.create_blob(size)
if isinstance(address, (int, np.integer)):
buffer.copy(0, int(address), size)
else:
buffer.copy(0, address)
return buffer.seal(client)
if isinstance(address, (int, np.integer)):
buffer.copy(0, int(address), size)
elif address is not None:
buffer.copy(0, address)
return client.get_meta(client.create_remote_blob(buffer))

if client.is_ipc:
if size == 0 or address is None:
return client.create_empty_blob()
existing = client.find_shared_memory(address)
if existing is not None:
return client.get_meta(existing)
buffer = client.create_blob(size)
if isinstance(address, (int, np.integer)):
buffer.copy(0, int(address), size)
elif address is not None:
buffer.copy(0, address)
return buffer.seal(client)


def build_numpy_buffer(client, array):
Expand Down
78 changes: 39 additions & 39 deletions src/client/ds/remote_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,45 @@ ObjectID RemoteBlob::id() const { return id_; }

ObjectID RemoteBlob::instance_id() const { return instance_id_; }

void RemoteBlob::Construct(ObjectMeta const& meta) {
std::string __type_name = type_name<RemoteBlob>();
VINEYARD_ASSERT(meta.GetTypeName() == __type_name,
"Expect typename '" + __type_name + "', but got '" +
meta.GetTypeName() + "'");
this->meta_ = meta;
this->id_ = meta.GetId();

if (this->buffer_ != nullptr) {
return;
}
if (this->id_ == EmptyBlobID() || meta.GetNBytes() == 0) {
this->size_ = 0;
return;
}

if (meta.GetClient()->IsRPC() &&
meta.GetClient()->remote_instance_id() != meta.GetInstanceId()) {
throw std::runtime_error(
"RemoteBlob::Construct(): Invalid internal state: remote blob found "
"but it is not located with the instance connected by rpc client");
}

if (meta.GetBuffer(meta.GetId(), this->buffer_).ok()) {
if (this->buffer_ == nullptr) {
throw std::runtime_error(
"RemoteBlob::Construct(): Invalid internal state: remote blob found "
"but it is nullptr: " +
ObjectIDToString(meta.GetId()));
}
this->size_ = this->buffer_->size();
} else {
throw std::runtime_error(
"RemoteBlob::Construct(): Invalid internal state: failed to construct "
"remote blob since payload is missing: " +
ObjectIDToString(meta.GetId()));
}
}

size_t RemoteBlob::size() const { return allocated_size(); }

size_t RemoteBlob::allocated_size() const { return size_; }
Expand Down Expand Up @@ -150,45 +189,6 @@ const std::shared_ptr<MutableBuffer>& RemoteBlobWriter::Buffer() const {

Status RemoteBlobWriter::Abort() { return Status::OK(); }

void RemoteBlob::Construct(ObjectMeta const& meta) {
std::string __type_name = type_name<RemoteBlob>();
VINEYARD_ASSERT(meta.GetTypeName() == __type_name,
"Expect typename '" + __type_name + "', but got '" +
meta.GetTypeName() + "'");
this->meta_ = meta;
this->id_ = meta.GetId();

if (this->buffer_ != nullptr) {
return;
}
if (this->id_ == EmptyBlobID() || meta.GetNBytes() == 0) {
this->size_ = 0;
return;
}

if (meta.GetClient()->IsRPC() &&
meta.GetClient()->remote_instance_id() != meta.GetInstanceId()) {
throw std::runtime_error(
"RemoteBlob::Construct(): Invalid internal state: remote blob found "
"but it is not located with the instance connected by rpc client");
}

if (meta.GetBuffer(meta.GetId(), this->buffer_).ok()) {
if (this->buffer_ == nullptr) {
throw std::runtime_error(
"RemoteBlob::Construct(): Invalid internal state: remote blob found "
"but it is nullptr: " +
ObjectIDToString(meta.GetId()));
}
this->size_ = this->buffer_->size();
} else {
throw std::runtime_error(
"RemoteBlob::Construct(): Invalid internal state: failed to construct "
"remote blob since payload is missing: " +
ObjectIDToString(meta.GetId()));
}
}

void RemoteBlobWriter::Dump() const {
#ifndef NDEBUG
std::stringstream ss;
Expand Down
4 changes: 2 additions & 2 deletions src/client/rpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ Status RPCClient::CreateRemoteBlob(
if (compressor && buffer->size() > 0) {
RETURN_ON_ERROR(detail::compress_and_send(compressor, vineyard_conn_,
buffer->data(), buffer->size()));
} else {
} else if (buffer->size() > 0) {
RETURN_ON_ERROR(send_bytes(vineyard_conn_, buffer->data(), buffer->size()));
}
json message_in;
Expand Down Expand Up @@ -437,7 +437,7 @@ Status RPCClient::GetRemoteBlob(const ObjectID& id, const bool unsafe,
RETURN_ON_ERROR(detail::recv_and_decompress(decompressor, vineyard_conn_,
buffer->mutable_data(),
payloads[0].data_size));
} else {
} else if (payloads[0].data_size > 0) {
RETURN_ON_ERROR(recv_bytes(vineyard_conn_, buffer->mutable_data(),
payloads[0].data_size));
}
Expand Down

0 comments on commit 52ad1a9

Please sign in to comment.