Skip to content

Commit

Permalink
[Op] Implement FileSliceSend/FileSliceRecvOp. (#960)
Browse files Browse the repository at this point in the history
FileSliceSend/FileSliceRecv Op transfer scalar string Tensor to/from SliceRecv/SliceSend Op.

Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty authored Dec 26, 2023
1 parent 6bf5621 commit 0f536a2
Show file tree
Hide file tree
Showing 15 changed files with 1,388 additions and 103 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,7 @@ tf_gen_op_libs(
"encode_proto_ops",
"experimental_dataset_ops",
"feature_column_ops",
"file_slice_sendrecv_ops",
"function_ops",
"functional_ops",
"fused_embedding_ops",
Expand Down Expand Up @@ -1465,6 +1466,7 @@ cc_library(
":encode_proto_ops_op_lib",
":experimental_dataset_ops_op_lib",
":feature_column_ops_op_lib",
":file_slice_sendrecv_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":fused_embedding_ops_op_lib",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/framework/rendezvous.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class Rendezvous : public core::RefCounted {
friend class FuseRecvOp;
friend class SliceSendOp;
friend class SliceRecvOp;
friend class FileSliceSendOp;
friend class FileSliceRecvOp;
friend class RefSendOp;
friend class RefRecvOp;
string buf_;
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
{"_HostSend", NC_HOST_SEND},
{"_RefSend", NC_REF_SEND},
{"_SliceSend", NC_SLICE_SEND},
{"_FileSliceSend", NC_FILE_SLICE_SEND},
{"_Recv", NC_RECV},
{"_HostRecv", NC_HOST_RECV},
{"_RefRecv", NC_REF_RECV},
{"_FuseRecv", NC_FUSE_RECV},
{"_HostFuseRecv", NC_HOST_FUSE_RECV},
{"_SliceRecv", NC_SLICE_RECV},
{"_FileSliceRecv", NC_FILE_SLICE_RECV},
{"Const", NC_CONSTANT},
{"HostConst", NC_CONSTANT},
{"Variable", NC_VARIABLE},
Expand Down
12 changes: 10 additions & 2 deletions tensorflow/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,19 @@ class Node {
bool IsSend() const { return class_ == NC_SEND ||
class_ == NC_HOST_SEND ||
class_ == NC_REF_SEND ||
class_ == NC_SLICE_SEND; }
class_ == NC_SLICE_SEND ||
class_ == NC_FILE_SLICE_SEND; }
bool IsSliceSend() const { return class_ == NC_SLICE_SEND; }
bool IsFileSliceSend() const { return class_ == NC_FILE_SLICE_SEND; }
bool IsRecv() const { return class_ == NC_RECV ||
class_ == NC_HOST_RECV ||
class_ == NC_REF_RECV ||
class_ == NC_SLICE_RECV; }
class_ == NC_SLICE_RECV ||
class_ == NC_FILE_SLICE_RECV; }
bool IsFuseRecv() const { return class_ == NC_FUSE_RECV ||
class_ == NC_HOST_FUSE_RECV; }
bool IsSliceRecv() const {return class_ == NC_SLICE_RECV; }
bool IsFileSliceRecv() const { return class_ == NC_FILE_SLICE_RECV; }
bool IsConstant() const { return class_ == NC_CONSTANT; }
bool IsStage() const { return class_ == NC_TENSOR_BUFFER_PUT; }
bool IsUnstage() const { return class_ == NC_TENSOR_BUFFER_TAKE; }
Expand Down Expand Up @@ -339,12 +343,14 @@ class Node {
NC_HOST_SEND,
NC_REF_SEND,
NC_SLICE_SEND,
NC_FILE_SLICE_SEND,
NC_RECV,
NC_HOST_RECV,
NC_REF_RECV,
NC_FUSE_RECV,
NC_HOST_FUSE_RECV,
NC_SLICE_RECV,
NC_FILE_SLICE_RECV,
NC_CONSTANT,
NC_VARIABLE,
NC_KV_VAR_HANDLE,
Expand Down Expand Up @@ -851,8 +857,10 @@ inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsSliceSend(const Node* node) { return node->IsSliceSend(); }
inline bool IsFileSliceSend(const Node* node) { return node->IsFileSliceSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsSliceRecv(const Node* node) { return node->IsSliceRecv(); }
inline bool IsFileSliceRecv(const Node* node) { return node->IsFileSliceRecv(); }
inline bool IsFuseRecv(const Node* node) { return node->IsFuseRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
Expand Down
10 changes: 8 additions & 2 deletions tensorflow/core/grappler/op_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }

bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }

bool IsFileSliceRecv(const NodeDef& node) { return node.op() == "_FileSliceRecv"; }

bool IsFileSliceSend(const NodeDef& node) { return node.op() == "_FileSliceSend"; }

bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }

bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
Expand Down Expand Up @@ -454,7 +458,8 @@ bool IsReciprocalGrad(const NodeDef& node) {
}

bool IsRecv(const NodeDef& node) {
return node.op() == "_Recv" || node.op() == "_HostRecv" || IsSliceRecv(node);
return node.op() == "_Recv" || node.op() == "_HostRecv" ||
IsSliceRecv(node) || IsFileSliceRecv(node);
}

bool IsFuseRecv(const NodeDef& node) {
Expand Down Expand Up @@ -502,7 +507,8 @@ bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }

bool IsSend(const NodeDef& node) {
return node.op() == "_Send" || node.op() == "_HostSend" || IsSliceSend(node);
return node.op() == "_Send" || node.op() == "_HostSend" ||
IsSliceSend(node) || IsFileSliceSend(node);
}

bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ bool IsExit(const NodeDef& node);
bool IsExp(const NodeDef& node);
bool IsFakeParam(const NodeDef& node);
bool IsFill(const NodeDef& node);
bool IsFileSliceRecv(const NodeDef& node);
bool IsFileSliceSend(const NodeDef& node);
bool IsFloorDiv(const NodeDef& node);
bool IsFloorMod(const NodeDef& node);
bool IsFusedBatchNorm(const NodeDef& node);
Expand Down
46 changes: 45 additions & 1 deletion tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5423,6 +5423,7 @@ cc_library(
name = "required",
deps = [
":no_op",
":file_slice_sendrecv_ops",
":fuserecv_ops",
":sendrecv_ops",
":slice_sendrecv_ops",
Expand All @@ -5446,10 +5447,33 @@ tf_kernel_library(
deps = REQUIRED_DEPS,
)

cc_library(
name = "slice_sendrecv_utils",
hdrs = [
"slice_sendrecv_utils.h"
],
srcs = [
"slice_sendrecv_utils.cc",
],
deps = [
"//tensorflow/core:framework",
]
)

tf_kernel_library(
name = "slice_sendrecv_ops",
prefix = "slice_sendrecv_ops",
deps = REQUIRED_DEPS,
deps = REQUIRED_DEPS + [
":slice_sendrecv_utils",
],
)

tf_kernel_library(
name = "file_slice_sendrecv_ops",
prefix = "file_slice_sendrecv_ops",
deps = REQUIRED_DEPS + [
":slice_sendrecv_utils",
],
)

tf_kernel_library(
Expand Down Expand Up @@ -5534,6 +5558,26 @@ tf_cc_test(
],
)

tf_cc_test(
name = "file_slice_sendrecv_ops_test",
srcs = ["file_slice_sendrecv_ops_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
deps = [
":control_flow_ops",
":cwise_op",
":file_slice_sendrecv_ops",
":logging_ops",
":ops_testutil",
":ops_util",
":slice_sendrecv_ops",
":whole_file_read_ops",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_kernel_library(
name = "fuserecv_ops",
prefix = "fuserecv_ops",
Expand Down
Loading

0 comments on commit 0f536a2

Please sign in to comment.