Skip to content

Commit

Permalink
feat: ByteStream auto mime_type detection and base64 (de)encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
LastRemote committed Dec 13, 2024
1 parent 7ade6a2 commit 9dcaa6b
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 2 deletions.
7 changes: 5 additions & 2 deletions haystack_experimental/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from haystack_experimental.dataclasses.byte_stream import ByteStream
from haystack_experimental.dataclasses.chat_message import (
ChatMessage,
ChatMessageContentT,
Expand All @@ -18,12 +19,14 @@

__all__ = [
"AsyncStreamingCallbackT",
"ByteStream",
"ChatMessage",
"ChatMessageContentT",
"ChatRole",
"MediaContent",
"StreamingCallbackT",
"TextContent",
"ToolCall",
"ToolCallResult",
"TextContent",
"ChatMessageContentT",
"Tool",
]
152 changes: 152 additions & 0 deletions haystack_experimental/dataclasses/byte_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
"""
Data classes for representing binary data in the Haystack API. The ByteStream class can be used to represent binary data
in the API, and can be converted to and from base64 encoded strings, dictionaries, and files. This is particularly
useful for representing media files in chat messages.
"""

import logging
import mimetypes
from base64 import b64encode, b64decode
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional


logger = logging.getLogger(__name__)


@dataclass
class ByteStream:
"""
Base data class representing a binary object in the Haystack API.
"""

data: bytes
meta: Dict[str, Any] = field(default_factory=dict, hash=False)
mime_type: Optional[str] = field(default=None)

@property
def type(self) -> Optional[str]:
"""
Return the type of the ByteStream. This is the first part of the mime type, or None if the mime type is not set.
:return: The type of the ByteStream.
"""
if self.mime_type:
return self.mime_type.split("/", maxsplit=1)[0]
return None

@property
def subtype(self) -> Optional[str]:
"""
Return the subtype of the ByteStream. This is the second part of the mime type,
or None if the mime type is not set.
:return: The subtype of the ByteStream.
"""
if self.mime_type:
return self.mime_type.split("/", maxsplit=1)[-1]
return None

def to_file(self, destination_path: Path):
"""
Write the ByteStream to a file. Note: the metadata will be lost.
:param destination_path: The path to write the ByteStream to.
"""
with open(destination_path, "wb") as fd:
fd.write(self.data)

@classmethod
def from_file_path(
cls, filepath: Path, mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None
) -> "ByteStream":
"""
Create a ByteStream from the contents read from a file.
:param filepath: A valid path to a file.
:param mime_type: The mime type of the file.
:param meta: Additional metadata to be stored with the ByteStream.
"""
if mime_type is None:
mime_type = mimetypes.guess_type(filepath)[0]
if mime_type is None:
logger.warning("Could not determine mime type for file %s", filepath)

with open(filepath, "rb") as fd:
return cls(data=fd.read(), mime_type=mime_type, meta=meta or {})

@classmethod
def from_string(
cls, text: str, encoding: str = "utf-8", mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None
) -> "ByteStream":
"""
Create a ByteStream encoding a string.
:param text: The string to encode
:param encoding: The encoding used to convert the string into bytes
:param mime_type: The mime type of the file.
:param meta: Additional metadata to be stored with the ByteStream.
"""
return cls(data=text.encode(encoding), mime_type=mime_type, meta=meta or {})

def to_string(self, encoding: str = "utf-8") -> str:
"""
Convert the ByteStream to a string, metadata will not be included.
:param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8".
:returns: The string representation of the ByteStream.
:raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding.
"""
return self.data.decode(encoding)

@classmethod
def from_base64(
cls,
base64_string: str,
encoding: str = "utf-8",
meta: Optional[Dict[str, Any]] = None,
mime_type: Optional[str] = None,
) -> "ByteStream":
"""
Create a ByteStream from a base64 encoded string.
:param base64_string: The base64 encoded string representation of the ByteStream data.
:param encoding: The encoding used to convert the base64 string into bytes.
:param meta: Additional metadata to be stored with the ByteStream.
:param mime_type: The mime type of the file.
:returns: A new ByteStream instance.
"""
return cls(data=b64decode(base64_string.encode(encoding)), meta=meta or {}, mime_type=mime_type)

def to_base64(self, encoding: str = "utf-8") -> str:
"""
Convert the ByteStream data to a base64 encoded string.
:returns: The base64 encoded string representation of the ByteStream data.
"""
return b64encode(self.data).decode(encoding)

@classmethod
def from_dict(cls, data: Dict[str, Any], encoding: str = "utf-8") -> "ByteStream":
"""
Create a ByteStream from a dictionary.
:param data: The dictionary representation of the ByteStream.
:param encoding: The encoding used to convert the base64 string into bytes.
:returns: A new ByteStream instance.
"""
return cls.from_base64(data["data"], encoding=encoding, meta=data.get("meta"), mime_type=data.get("mime_type"))

def to_dict(self, encoding: str = "utf-8"):
"""
Convert the ByteStream to a dictionary.
:param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8".
:returns: The dictionary representation of the ByteStream.
:raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding.
"""
return {"data": self.to_base64(encoding=encoding), "meta": self.meta, "mime_type": self.mime_type}
Empty file added test/dataclasses/__init__.py
Empty file.
91 changes: 91 additions & 0 deletions test/dataclasses/test_byte_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
from base64 import b64encode
from pathlib import Path
from unittest.mock import mock_open, patch

from haystack_experimental.dataclasses.byte_stream import ByteStream

@pytest.fixture
def byte_stream():
test_data = b"test data"
test_meta = {"key": "value"}
test_mime = "text/plain"
return ByteStream(data=test_data, meta=test_meta, mime_type=test_mime)

def test_init(byte_stream):
assert byte_stream.data == b"test data"
assert byte_stream.meta == {"key": "value"}
assert byte_stream.mime_type == "text/plain"

def test_type_property(byte_stream):
assert byte_stream.type == "text"
stream_without_mime = ByteStream(data=b"test data")
assert stream_without_mime.type is None

def test_subtype_property(byte_stream):
assert byte_stream.subtype == "plain"
stream_without_mime = ByteStream(data=b"test data")
assert stream_without_mime.subtype is None

@patch("builtins.open", new_callable=mock_open)
def test_to_file(mock_file, byte_stream):
path = Path("test.txt")
byte_stream.to_file(path)
mock_file.assert_called_once_with(path, "wb")
mock_file().write.assert_called_once_with(b"test data")

@patch("builtins.open", new_callable=mock_open, read_data=b"test data")
def test_from_file_path(mock_file):
path = Path("test.txt")
with patch("mimetypes.guess_type", return_value=("text/plain", None)):
byte_stream = ByteStream.from_file_path(path)
assert byte_stream.data == b"test data"
assert byte_stream.mime_type == "text/plain"

@patch("mimetypes.guess_type", return_value=(None, None))
@patch("haystack_experimental.dataclasses.byte_stream.logger.warning")
def test_from_file_path_unknown_mime(mock_warning, _, byte_stream):
path = Path("test.txt")
with patch("builtins.open", new_callable=mock_open, read_data=b"test data"):
byte_stream = ByteStream.from_file_path(path)
assert byte_stream.mime_type is None
mock_warning.assert_called_once()

def test_from_string():
text = "Hello, World!"
byte_stream = ByteStream.from_string(text, mime_type="text/plain")
assert byte_stream.data == text.encode("utf-8")
assert byte_stream.mime_type == "text/plain"

def test_to_string():
byte_stream = ByteStream(data=b"Hello, World!")
assert byte_stream.to_string() == "Hello, World!"

def test_from_base64():
base64_string = b64encode(b"test data").decode("utf-8")
byte_stream = ByteStream.from_base64(base64_string, mime_type="text/plain")
assert byte_stream.data == b"test data"
assert byte_stream.mime_type == "text/plain"

def test_to_base64(byte_stream):
expected = b64encode(b"test data").decode("utf-8")
assert byte_stream.to_base64() == expected

def test_from_dict():
data = {
"data": b64encode(b"test data").decode("utf-8"),
"meta": {"key": "value"},
"mime_type": "text/plain",
}
byte_stream = ByteStream.from_dict(data)
assert byte_stream.data == b"test data"
assert byte_stream.meta == {"key": "value"}
assert byte_stream.mime_type == "text/plain"

def test_to_dict(byte_stream):
expected = {
"data": b64encode(b"test data").decode("utf-8"),
"meta": {"key": "value"},
"mime_type": "text/plain",
}
assert byte_stream.to_dict() == expected

0 comments on commit 9dcaa6b

Please sign in to comment.