Skip to content

Commit

Permalink
new lavalink and wavelink updates/fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Brettanda committed May 4, 2024
1 parent 0634d7d commit 506c2dc
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 98 deletions.
164 changes: 69 additions & 95 deletions cogs/music.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import copy as co
from yarl import URL
from typing_extensions import Annotated
from wavelink.ext import spotify

from functions import (MessageColors, checks, config, embed, exceptions,
paginator)
Expand Down Expand Up @@ -168,79 +167,56 @@ class Track(wavelink.Playable):
requester: discord.Member


class SearchType(Enum):
Generic = wavelink.GenericTrack # direct playback url
YouTube = wavelink.YouTubeTrack
YouTubeMusic = wavelink.YouTubeMusicTrack
YouTubePlaylist = wavelink.YouTubePlaylist
SoundCloud = wavelink.SoundCloudTrack
SoundCloudPlaylist = wavelink.SoundCloudPlaylist
Spotify = spotify.SpotifyTrack


class CustomSearch(discord.app_commands.Transformer):
@staticmethod
def get_platform(value: str) -> SearchType:
def get_platform(value: str) -> str:
if value.startswith("http"):
link = URL(value)
log.info(link.host)
log.info(link.path)
log.info(link.query)
if link.host and ("youtube.com" in link.host or "youtu.be" in link.host):
if link.path == "/playlist" or link.query.get("list") is not None:
return SearchType.YouTubePlaylist
return SearchType.YouTube
return "youtubeplaylist"
return "youtube"
elif link.host and "soundcloud.com" in link.host:
return SearchType.SoundCloud
return "soundcloud"
elif link.host and "spotify.com" in link.host:
return SearchType.Spotify
return SearchType.Generic
return SearchType.YouTube

@classmethod
async def transform(cls, interaction: discord.Interaction, value: str, /) -> list[Track]:
platform = cls.get_platform(value)
value = value.strip("<>")
log.info(platform)
if platform is SearchType.Spotify:
decoded = spotify.decode_url(value)
if decoded:
if decoded['type'] is spotify.SpotifySearchType.playlist:
tracks = []
async for t in spotify.SpotifyTrack.iterator(query=value):
t.requester = interaction.user # type: ignore
tracks.append(t)
return tracks
tracks = await platform.value.search(value)
if platform == SearchType.YouTubePlaylist and isinstance(tracks, wavelink.YouTubePlaylist):
tracks = tracks.tracks
for t in tracks:
t.requester = interaction.user # type: ignore
return tracks # type: ignore
return "spotify"
return "youtube"

# @classmethod
# async def transform(cls, interaction: discord.Interaction, value: str, /) -> list[Track]:
# platform = cls.get_platform(value)
# value = value.strip("<>")
# log.info(platform)
# if platform is SearchType.Spotify:
# decoded = spotify.decode_url(value)
# if decoded:
# if decoded['type'] is spotify.SpotifySearchType.playlist:
# tracks = []
# async for t in spotify.SpotifyTrack.iterator(query=value):
# t.requester = interaction.user
# tracks.append(t)
# return tracks
# tracks = await platform.value.search(value)
# if platform == SearchType.YouTubePlaylist and isinstance(tracks, wavelink.YouTubePlaylist):
# tracks = tracks.tracks
# for t in tracks:
# t.requester = interaction.user
# return tracks

async def convert(self, ctx: MyContext, value: str) -> list[Track]:
platform = self.get_platform(value)
value = value.strip("<>")
log.info(platform)
tracks: list[Track] = []
if platform is SearchType.Spotify:
decoded = spotify.decode_url(value)
if decoded:
if decoded['type'] is spotify.SpotifySearchType.playlist:
async for t in spotify.SpotifyTrack.iterator(query=value):
t.requester = ctx.author # type: ignore
tracks.append(t) # type: ignore
return tracks
tracks = await platform.value.search(value) # type: ignore
if platform == SearchType.YouTubePlaylist and isinstance(tracks, wavelink.YouTubePlaylist):
tracks = tracks.tracks # type: ignore
tracks = await wavelink.Playable.search(value)
if isinstance(tracks, list):
for t in tracks:
t.requester = ctx.author # type: ignore
t.requester = ctx.author
return tracks

async def autocomplete(self, interaction: discord.Interaction, value: str, /) -> List[discord.app_commands.Choice[str]]:
platform = self.get_platform(value)
search = await platform.value.search(value)
search = await wavelink.Playable.search(value)
return [discord.app_commands.Choice(name=track.title, value=track.uri or track.title) for track in search[:25]]


Expand Down Expand Up @@ -275,11 +251,11 @@ def __init__(
super().__init__(*args, **kwargs)


class TrackEventPayload(wavelink.TrackEventPayload):
player: Player
track: Track
ctx: GuildContext
original: Track | None
# class TrackEventPayload(wavelink.TrackEventPayload):
# player: Player
# track: Track
# ctx: GuildContext
# original: Track | None


NUMTOEMOTES = {
Expand Down Expand Up @@ -323,18 +299,13 @@ def __repr__(self) -> str:
async def cog_load(self):
nodes = [
wavelink.Node(
id=f"{os.environ.get('LAVALINKUSID','MAIN')}",
identifier=f"{os.environ.get('LAVALINKUSID','MAIN')}",
uri=f"http://{os.environ['LAVALINKUSHOST']}:{os.environ['LAVALINKUSPORT']}",
password=os.environ["LAVALINKUSPASS"],
)
]

spotify_client = spotify.SpotifyClient(
client_id=os.environ["SPOTIFYID"],
client_secret=os.environ["SPOTIFYSECRET"],
)

await wavelink.NodePool.connect(client=self.bot, nodes=nodes, spotify=spotify_client)
await wavelink.Pool.connect(client=self.bot, nodes=nodes)#, spotify=spotify_client)

def cog_check(self, ctx: MyContext) -> bool:
if not ctx.guild:
Expand All @@ -360,11 +331,15 @@ async def cog_command_error(self, ctx: MyContext, error: commands.CommandError):
log.error(f"Error in {ctx.command.qualified_name}: {type(error).__name__}: {error}")

@commands.Cog.listener()
async def on_wavelink_node_ready(self, node: wavelink.Node):
log.info(f"Node {node.id} is ready!")
async def on_wavelink_node_ready(self, payload: wavelink.NodeReadyEventPayload):
log.info(f"Node {payload.node.identifier} is ready!")

@commands.Cog.listener()
async def on_wavelink_track_exception(self, payload: wavelink.TrackExceptionEventPayload) -> None:
print(payload.exception)

@commands.Cog.listener()
async def on_wavelink_track_start(self, payload: TrackEventPayload):
async def on_wavelink_track_start(self, payload: wavelink.TrackStartEventPayload): # TrackEventPayload
if payload.player.channel is None:
return

Expand All @@ -390,11 +365,11 @@ async def on_wavelink_track_start(self, payload: TrackEventPayload):
# await payload.player.channel.create_instance(topic=f"🎵 {payload.track.title}{' by ' + payload.track.requester if payload.track.requester is not None else ''}", reason="Next track started.")

@commands.Cog.listener()
async def on_wavelink_track_end(self, payload: TrackEventPayload):
if not payload.player.queue.is_empty:
async def on_wavelink_inactive_player(self, player: wavelink.Player): #TrackEventPayload):
if not player.queue.is_empty:
return

await payload.player.disconnect()
await player.disconnect()

def required(self, ctx: GuildContext, player: Player) -> int:
channel = player.channel
Expand Down Expand Up @@ -470,7 +445,6 @@ async def play(self, ctx: GuildContext, *, query: discord.app_commands.Transform
player.ctx = ctx
else:
player: Player = ctx.voice_client # type: ignore
player.autoplay = True

# if player.channel.instance is None:
# await player.channel.create_instance(topic=track.title, reason="Music time!")
Expand All @@ -484,7 +458,6 @@ async def play(self, ctx: GuildContext, *, query: discord.app_commands.Transform
await ctx.guild.me.request_to_speak()

new_tracks = []
is_playing = co.deepcopy(player.is_playing())

async def _play(track):
track.requester = ctx.author
Expand All @@ -498,10 +471,10 @@ async def _play(track):
except BaseException:
raise TrackNotFound()

if not player.is_playing():
await player.play(await player.queue.get_wait())
if not player.playing:
await player.play(query[0])

if is_playing:
if player.playing:
if len(new_tracks) == 1:
await ctx.send(embed=embed(title=f"Added **{new_tracks[0].title}** to the queue.", color=MessageColors.music()))
else:
Expand All @@ -515,22 +488,22 @@ async def pause(self, ctx: GuildContext):
"""Pause the currently playing song."""
player: Player | None = ctx.voice_client # type: ignore

if player is None or player.is_paused() or player.current is None:
if player is None or player.paused or player.current is None:
raise NothingPlaying()

if self.is_privileged(ctx):
await ctx.send(embed=embed(title='An admin or DJ has paused the player.', color=MessageColors.music()))
player.pause_votes.clear()

return await player.pause()
return await player.pause(True)

required = self.required(ctx, player=player)
player.pause_votes.add(ctx.author)

if len(player.pause_votes) >= required:
await ctx.send(embed=embed(title='Vote to pause passed. Pausing player.', color=MessageColors.music()))
player.pause_votes.clear()
await player.pause()
await player.pause(True)
else:
await ctx.send(embed=embed(title=f'{ctx.author} has voted to pause the player.', color=MessageColors.music()))

Expand All @@ -542,22 +515,22 @@ async def resume(self, ctx: GuildContext):
"""Resume a currently paused player."""
player: Player | None = ctx.voice_client # type: ignore

if player is None or not player.is_paused() or player.current is None:
if player is None or not player.paused or player.current is None:
raise NothingPlaying()

if self.is_privileged(ctx):
await ctx.send(embed=embed(title='An admin or DJ has resumed the player.', color=MessageColors.music()))
player.resume_votes.clear()

return await player.resume()
return await player.pause(False)

required = self.required(ctx, player)
player.resume_votes.add(ctx.author)

if len(player.resume_votes) >= required:
await ctx.send(embed=embed(title='Vote to resume passed. Resuming player.', color=MessageColors.music()))
player.resume_votes.clear()
await player.resume()
await player.pause(False)
else:
await ctx.send(embed=embed(title=f'{ctx.author.mention} has voted to resume the player.', color=MessageColors.music()))

Expand All @@ -574,12 +547,11 @@ async def loop(self, ctx: GuildContext, type: Literal['all', 'one'] = None) -> N

def set_loop(type):
if type == "one":
player.queue.loop = not player.queue.loop
player.queue.mode = wavelink.QueueMode.loop
elif type == "all":
player.queue.loop_all = not player.queue.loop_all
player.queue.mode = wavelink.QueueMode.loop_all
else:
player.queue.loop = False
player.queue.loop_all = False
player.queue.mode = wavelink.QueueMode.normal

if self.is_privileged(ctx):
await ctx.send(embed=embed(title=f'An admin or DJ has set the loop to {type}.', color=MessageColors.music()))
Expand All @@ -603,21 +575,21 @@ async def skip(self, ctx: GuildContext):
await ctx.send(embed=embed(title='An admin or DJ has skipped the song.', color=MessageColors.music()))
player.skip_votes.clear()

return await player.stop()
return await player.skip(force=False)

if hasattr(player.current, "requester") and ctx.author == player.current.requester:
await ctx.send(embed=embed(title='The song requester has skipped the song.', color=MessageColors.music()))
player.skip_votes.clear()

return await player.stop()
return await player.skip(force=False)

required = self.required(ctx, player)
player.skip_votes.add(ctx.author)

if len(player.skip_votes) >= required:
await ctx.send(embed=embed(title='Vote to skip passed. Skipping song.', color=MessageColors.music()))
player.skip_votes.clear()
await player.stop()
await player.skip(force=False)
else:
await ctx.send(embed=embed(title=f'{ctx.author} has voted to skip the song. {len(player.skip_votes)}/{required}', color=MessageColors.music()))

Expand Down Expand Up @@ -720,8 +692,10 @@ async def equalizer(self, ctx: GuildContext, *, equalizer: Annotated[str, Equali
joined = "\n".join(eqs.keys())
return await ctx.send(embed=embed(title=f'Invalid EQ provided. Valid EQs:\n\n{joined}', color=MessageColors.error()))

filters = player.filters
filters.equalizer.set(bands=eq.eq)
await player.set_filters(filters)
await ctx.send(embed=embed(title=f'Successfully changed equalizer to {equalizer}', color=MessageColors.music()))
await player.set_filter(wavelink.Filter(equalizer=eq))

@commands.group(name="queue", aliases=['que'], invoke_without_command=True)
@commands.guild_only()
Expand All @@ -737,7 +711,7 @@ async def queue(self, ctx: GuildContext):
if player.queue.is_empty:
return await ctx.send(embed=embed(title='There are no more songs in the queue.', color=MessageColors.error()))

entries = [track.title for track in player.queue._queue]
entries = [track.title for track in player.queue]
source = PaginatorSource(entries=entries)
pages = paginator.RoboPages(source=source, ctx=ctx, compact=True)

Expand All @@ -753,9 +727,9 @@ async def queue_remove(self, ctx: GuildContext, *, index: int):
if not player or not player.current:
raise NothingPlaying()

if index < 1 or index > len(player.queue._queue):
if index < 1 or index > player.queue.count:
return await ctx.send(embed=embed(title='Invalid index provided.', color=MessageColors.error()))
entry = player.queue._queue[index - 1]
entry = player.queue[index - 1]

if self.is_privileged(ctx):
del player.queue[index - 1]
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ services:
retries: 5
start_period: 5s
lavalink:
image: fredboat/lavalink:3.7.8
image: ghcr.io/lavalink-devs/lavalink:4
restart: unless-stopped
env_file: .env
ports:
Expand Down
Loading

0 comments on commit 506c2dc

Please sign in to comment.