From 9e3145581483b3b6b5558cea0c8dca1a03ddd148 Mon Sep 17 00:00:00 2001 From: SeoulSKY Date: Sun, 1 Sep 2024 14:08:23 -0600 Subject: [PATCH] Use feed url to verify channel ids --- tests/test_youtube_notifier.py | 10 ++++++++++ ytnoti/__init__.py | 23 +++++++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/tests/test_youtube_notifier.py b/tests/test_youtube_notifier.py index 5542bed..d600c90 100644 --- a/tests/test_youtube_notifier.py +++ b/tests/test_youtube_notifier.py @@ -5,6 +5,7 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from httpx import AsyncClient from tests import CALLBACK_URL, get_video, notifier # noqa: F401 from ytnoti import Video, YouTubeNotifier @@ -213,6 +214,15 @@ def test_get_server(notifier: YouTubeNotifier) -> None: with pytest.raises(ValueError): notifier._get_server(host=host, port=port, app=app) +@pytest.mark.asyncio +async def test_verify_channel_id() -> None: + """Test verifying channel ID.""" + notifier = YouTubeNotifier() + async with AsyncClient() as client: + for channel_id in channel_ids: + assert await notifier._verify_channel_id(channel_id, client=client) + + assert not await notifier._verify_channel_id("invalid", client=client) def test_get(notifier: YouTubeNotifier) -> None: """Test the get method of the YouTubeNotifier class.""" diff --git a/ytnoti/__init__.py b/ytnoti/__init__.py index 42afc6f..f41def8 100644 --- a/ytnoti/__init__.py +++ b/ytnoti/__init__.py @@ -422,6 +422,22 @@ async def _is_listening(self) -> bool: return response.status_code == HTTPStatus.OK.value + @staticmethod + async def _verify_channel_id(channel_id: str, *, client: AsyncClient) -> bool: + """Verify the channel ID by sending a HEAD request to the YouTube channel. + + :param channel_id: The channel ID to verify. + :param client: The asynchronous HTTP client to use for the request. + :return: True if the channel ID is valid, False otherwise. + :raises HTTPError: If failed to verify the channel ID due to an HTTP error. + """ + response = await client.head( + f"https://www.youtube.com/feeds/videos.xml?channel_id={channel_id}" + ) + + return response.status_code == HTTPStatus.OK.value + + async def subscribe(self, channel_ids: str | Iterable[str]) -> Self: """Subscribe to YouTube channels to receive push notifications. This is lazy and will subscribe when the notifier is ready. @@ -438,12 +454,7 @@ async def subscribe(self, channel_ids: str | Iterable[str]) -> Self: async with AsyncClient() as client: for channel_id in channel_ids: - response = await client.head( - f"https://www.youtube.com/channel/{channel_id}" - ) - - if response.status_code != HTTPStatus.OK.value: - raise ValueError(f"Invalid channel ID: {channel_id}") + await self._verify_channel_id(channel_id, client=client) if not self.is_ready: self._subscribed_ids.update(channel_ids)