Skip to content

Commit

Permalink
fix: Writing to bytes TextIO stream in StreamWriter causing Attribute…
Browse files Browse the repository at this point in the history
…Error on buffer member (#6034)

* Fix writing to TextIO stream by decoding bytes

 - The TextIO's binary buffer member is not guaranteed to exist
   according to the python docs.
   (https://docs.python.org/3.8/library/io.html#io.TextIOBase.buffer)
 - This change decodes the bytes to a string so that we can use
   the write method on the TextIO stream.

* Update fix to use write instead of buffer (typo)

* Add test for StreamWriter with StringIO stream

* Remove write_bytes and use write_str instead

* Remove unused typing.Union import

* Use new variable to hold output string

---------

Co-authored-by: Mohamed Elasmar <[email protected]>
  • Loading branch information
sbchisholm and moelasmar authored Nov 7, 2023
1 parent b8d372a commit 70b34aa
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 34 deletions.
4 changes: 2 additions & 2 deletions samcli/commands/remote/remote_invoke_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ class DefaultRemoteInvokeResponseConsumer(RemoteInvokeConsumer[RemoteInvokeRespo
_stream_writer: StreamWriter

def consume(self, remote_invoke_response: RemoteInvokeResponse) -> None:
self._stream_writer.write_bytes(cast(str, remote_invoke_response.response).encode())
self._stream_writer.write_str(cast(str, remote_invoke_response.response))


@dataclass
Expand All @@ -290,4 +290,4 @@ class DefaultRemoteInvokeLogConsumer(RemoteInvokeConsumer[RemoteInvokeLogOutput]
_stream_writer: StreamWriter

def consume(self, remote_invoke_response: RemoteInvokeLogOutput) -> None:
self._stream_writer.write_bytes(remote_invoke_response.log_output.encode())
self._stream_writer.write_str(remote_invoke_response.log_output)
16 changes: 1 addition & 15 deletions samcli/lib/utils/stream_writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This class acts like a wrapper around output streams to provide any flexibility with output we need
"""
from typing import TextIO, Union
from typing import TextIO


class StreamWriter:
Expand All @@ -23,20 +23,6 @@ def __init__(self, stream: TextIO, auto_flush: bool = False):
def stream(self) -> TextIO:
return self._stream

def write_bytes(self, output: Union[bytes, bytearray]):
"""
Writes specified text to the underlying stream
Parameters
----------
output bytes-like object
Bytes to write into buffer
"""
self._stream.buffer.write(output)

if self._auto_flush:
self._stream.flush()

def write_str(self, output: str):
"""
Writes specified text to the underlying stream
Expand Down
10 changes: 5 additions & 5 deletions samcli/local/docker/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def wait_for_result(self, full_path, event, stdout, stderr, start_timer=None):
if isinstance(response, str):
stdout.write_str(response)
elif isinstance(response, bytes):
stdout.write_bytes(response)
stdout.write_str(response.decode("utf-8"))
stdout.flush()
stderr.write_str("\n")
stderr.flush()
Expand Down Expand Up @@ -473,16 +473,16 @@ def _handle_data_writing(output_stream: Union[StreamWriter, io.BytesIO, io.TextI
# with carriage returns from the RIE. If these are left in the string then only the last line after
# the carriage return will be printed instead of the entire stack trace. Encode the string after cleaning
# to be printed by the correct output stream
output_data = output_data.decode("utf-8").replace("\r", os.linesep).encode("utf-8")
output_str = output_data.decode("utf-8").replace("\r", os.linesep)
if isinstance(output_stream, StreamWriter):
output_stream.write_bytes(output_data)
output_stream.write_str(output_str)
output_stream.flush()

if isinstance(output_stream, io.BytesIO):
output_stream.write(output_data)
output_stream.write(output_str.encode("utf-8"))

if isinstance(output_stream, io.TextIOWrapper):
output_stream.buffer.write(output_data)
output_stream.buffer.write(output_str.encode("utf-8"))

@property
def network_id(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/lib/utils/test_stream_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def test_must_write_to_stream(self):
stream_mock = Mock()

writer = StreamWriter(stream_mock)
writer.write_bytes(buffer)
writer.write_str(buffer.decode("utf-8"))

stream_mock.buffer.write.assert_called_once_with(buffer)
stream_mock.write.assert_called_once_with(buffer.decode("utf-8"))

def test_must_flush_underlying_stream(self):
stream_mock = Mock()
Expand Down
19 changes: 9 additions & 10 deletions tests/unit/local/docker/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,6 @@ def test_wait_for_result_no_error(self, response_deserializable, rie_response, m

stdout_mock = Mock()
stdout_mock.write_str = Mock()
stdout_mock.write_bytes = Mock()
stderr_mock = Mock()
response = Mock()
response.content = rie_response
Expand Down Expand Up @@ -639,7 +638,7 @@ def test_wait_for_result_no_error(self, response_deserializable, rie_response, m
if response_deserializable:
stdout_mock.write_str.assert_called_with(json.dumps(json.loads(rie_response), ensure_ascii=False))
else:
stdout_mock.write_bytes.assert_called_with(rie_response)
stdout_mock.write_str.assert_called_with(rie_response.decode("utf-8"))

@patch("socket.socket")
@patch("samcli.local.docker.container.requests")
Expand Down Expand Up @@ -752,8 +751,8 @@ def _output_iterator():
raise ValueError("The pipe has been ended.")

Container._write_container_output(_output_iterator(), stdout_mock, stderr_mock)
stdout_mock.assert_has_calls([call.write_bytes(b"Hello")])
stderr_mock.assert_has_calls([call.write_bytes(b"World")])
stdout_mock.assert_has_calls([call.write_str("Hello")])
stderr_mock.assert_has_calls([call.write_str("World")])


class TestContainer_wait_for_logs(TestCase):
Expand Down Expand Up @@ -815,25 +814,25 @@ def test_must_write_stdout_and_stderr_data(self):

Container._write_container_output(self.output_itr, stdout=self.stdout_mock, stderr=self.stderr_mock)

self.stdout_mock.write_bytes.assert_has_calls([call(b"stdout1"), call(b"stdout2")])
self.stdout_mock.write_str.assert_has_calls([call("stdout1"), call("stdout2")])

self.stderr_mock.write_bytes.assert_has_calls([call(b"stderr1"), call(b"stderr2")])
self.stderr_mock.write_str.assert_has_calls([call("stderr1"), call("stderr2")])

def test_must_write_only_stderr(self):
# All the invalid frames must be ignored

Container._write_container_output(self.output_itr, stdout=None, stderr=self.stderr_mock)

self.stdout_mock.write_bytes.assert_not_called()
self.stdout_mock.write_str.assert_not_called()

self.stderr_mock.write_bytes.assert_has_calls([call(b"stderr1"), call(b"stderr2")])
self.stderr_mock.write_str.assert_has_calls([call("stderr1"), call("stderr2")])

def test_must_write_only_stdout(self):
Container._write_container_output(self.output_itr, stdout=self.stdout_mock, stderr=None)

self.stdout_mock.write_bytes.assert_has_calls([call(b"stdout1"), call(b"stdout2")])
self.stdout_mock.write_str.assert_has_calls([call("stdout1"), call("stdout2")])

self.stderr_mock.write_bytes.assert_not_called() # stderr must never be called
self.stderr_mock.write_str.assert_not_called() # stderr must never be called


class TestContainer_wait_for_socket_connection(TestCase):
Expand Down

0 comments on commit 70b34aa

Please sign in to comment.