Skip to content

Commit

Permalink
Use feed url to verify channel ids
Browse files Browse the repository at this point in the history
  • Loading branch information
SeoulSKY committed Sep 1, 2024
1 parent af814dc commit 9e31455
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
10 changes: 10 additions & 0 deletions tests/test_youtube_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import AsyncClient

from tests import CALLBACK_URL, get_video, notifier # noqa: F401
from ytnoti import Video, YouTubeNotifier
Expand Down Expand Up @@ -213,6 +214,15 @@ def test_get_server(notifier: YouTubeNotifier) -> None:
with pytest.raises(ValueError):
notifier._get_server(host=host, port=port, app=app)

@pytest.mark.asyncio
async def test_verify_channel_id() -> None:
"""Test verifying channel ID."""
notifier = YouTubeNotifier()
async with AsyncClient() as client:
for channel_id in channel_ids:
assert await notifier._verify_channel_id(channel_id, client=client)

assert not await notifier._verify_channel_id("invalid", client=client)

def test_get(notifier: YouTubeNotifier) -> None:
"""Test the get method of the YouTubeNotifier class."""
Expand Down
23 changes: 17 additions & 6 deletions ytnoti/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ async def _is_listening(self) -> bool:

return response.status_code == HTTPStatus.OK.value

@staticmethod
async def _verify_channel_id(channel_id: str, *, client: AsyncClient) -> bool:
"""Verify the channel ID by sending a HEAD request to the YouTube channel.
:param channel_id: The channel ID to verify.
:param client: The asynchronous HTTP client to use for the request.
:return: True if the channel ID is valid, False otherwise.
:raises HTTPError: If failed to verify the channel ID due to an HTTP error.
"""
response = await client.head(
f"https://www.youtube.com/feeds/videos.xml?channel_id={channel_id}"
)

return response.status_code == HTTPStatus.OK.value


async def subscribe(self, channel_ids: str | Iterable[str]) -> Self:
"""Subscribe to YouTube channels to receive push notifications.
This is lazy and will subscribe when the notifier is ready.
Expand All @@ -438,12 +454,7 @@ async def subscribe(self, channel_ids: str | Iterable[str]) -> Self:

async with AsyncClient() as client:
for channel_id in channel_ids:
response = await client.head(
f"https://www.youtube.com/channel/{channel_id}"
)

if response.status_code != HTTPStatus.OK.value:
raise ValueError(f"Invalid channel ID: {channel_id}")
await self._verify_channel_id(channel_id, client=client)

if not self.is_ready:
self._subscribed_ids.update(channel_ids)
Expand Down

0 comments on commit 9e31455

Please sign in to comment.