diff --git a/.test.env b/.test.env index 959024d..2ac464c 100644 --- a/.test.env +++ b/.test.env @@ -30,6 +30,8 @@ CHANNEL_VERIFY_LOGS=1012769518828339331 CHANNEL_BOT_COMMANDS=1276953350848588101 CHANNEL_SPOILER=2769521890099371011 CHANNEL_BOT_LOGS=1105517088266788925 +CHANNEL_UNVERIFIED_BOT_COMMANDS=1430556712313688225 +CHANNEL_HOW_TO_VERIFY=1432333413980835840 # Roles ROLE_VERIFIED=1333333333333333337 diff --git a/src/cmds/automation/auto_verify.py b/src/cmds/automation/auto_verify.py index 98c538d..08b8475 100644 --- a/src/cmds/automation/auto_verify.py +++ b/src/cmds/automation/auto_verify.py @@ -4,6 +4,7 @@ from discord.ext import commands from src.bot import Bot +from src.core.config import settings logger = logging.getLogger(__name__) @@ -14,41 +15,29 @@ class MessageHandler(commands.Cog): def __init__(self, bot: Bot): self.bot = bot - async def process_reverification(self, member: Member | User) -> None: - """Re-verifation process for a member. - - TODO: Reimplement once it's possible to fetch link state from the HTB Account. - """ - raise VerificationError("Not implemented") - @commands.Cog.listener() @commands.cooldown(1, 60, commands.BucketType.user) async def on_message(self, ctx: Message) -> None: - """Run commands in the context of a message.""" - # Return if the message was sent by the bot to avoid recursion. + """Guide unverified users toward the verification channel.""" if ctx.author.bot: return - try: - await self.process_reverification(ctx.author) - except VerificationError as exc: - logger.debug(f"HTB Discord link for user {ctx.author.name} with ID {ctx.author.id} not found", exc_info=exc) + if ctx.channel.id == settings.channels.UNVERIFIED_BOT_COMMANDS: + await ctx.reply( + f"Hello! Welcome to the Hack The Box Discord! In order to access the full server, " + f"please verify your account by following the instructions in " + f"<#{settings.channels.HOW_TO_VERIFY}>.", + mention_author=True, + ) + return @commands.Cog.listener() @commands.cooldown(1, 3600, commands.BucketType.user) async def on_member_join(self, member: Member) -> None: """Run commands in the context of a member join.""" - try: - await self.process_reverification(member) - except VerificationError as exc: - logger.debug(f"HTB Discord link for user {member.name} with ID {member.id} not found", exc_info=exc) - - -class VerificationError(Exception): - """Verification error.""" + pass def setup(bot: Bot) -> None: """Load the `MessageHandler` cog.""" - # bot.add_cog(MessageHandler(bot)) - pass + bot.add_cog(MessageHandler(bot)) diff --git a/src/core/config.py b/src/core/config.py index f2a5e42..bc250fe 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -64,8 +64,13 @@ class Channels(BaseSettings): BOT_COMMANDS: int SPOILER: int BOT_LOGS: int + UNVERIFIED_BOT_COMMANDS: int = 0 + HOW_TO_VERIFY: int = 0 - @validator("DEVLOG", "SR_MOD", "VERIFY_LOGS", "BOT_COMMANDS", "SPOILER", "BOT_LOGS") + @validator( + "DEVLOG", "SR_MOD", "VERIFY_LOGS", "BOT_COMMANDS", "SPOILER", "BOT_LOGS", + "UNVERIFIED_BOT_COMMANDS", "HOW_TO_VERIFY", + ) def check_ids_format(cls, v: list[int]) -> list[int]: """Validate discord ids format.""" if not v: diff --git a/tests/src/cmds/automation/__init__.py b/tests/src/cmds/automation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/src/cmds/automation/test_auto_verify.py b/tests/src/cmds/automation/test_auto_verify.py new file mode 100644 index 0000000..d3923f9 --- /dev/null +++ b/tests/src/cmds/automation/test_auto_verify.py @@ -0,0 +1,57 @@ +from unittest import mock + +import pytest + +from src.cmds.automation import auto_verify +from src.core.config import settings +from tests import helpers + + +class TestMessageHandler: + """Test the `MessageHandler` cog.""" + + @pytest.mark.asyncio + async def test_on_message_in_unverified_channel_sends_welcome(self, bot): + """Test that a welcome message is sent when user posts in unverified channel.""" + cog = auto_verify.MessageHandler(bot) + + channel = helpers.MockTextChannel(id=settings.channels.UNVERIFIED_BOT_COMMANDS) + author = helpers.MockMember(bot=False) + message = helpers.MockMessage(channel=channel, author=author) + message.reply = mock.AsyncMock() + + await cog.on_message(message) + + message.reply.assert_called_once() + call_args = message.reply.call_args + assert "Welcome to the Hack The Box Discord" in call_args[0][0] + assert str(settings.channels.HOW_TO_VERIFY) in call_args[0][0] + assert call_args[1]["mention_author"] is True + + @pytest.mark.asyncio + async def test_on_message_in_other_channel_no_welcome(self, bot): + """Test that no welcome is sent in other channels.""" + cog = auto_verify.MessageHandler(bot) + + channel = helpers.MockTextChannel(id=999999999999999999) + author = helpers.MockMember(bot=False) + message = helpers.MockMessage(channel=channel, author=author) + message.reply = mock.AsyncMock() + + await cog.on_message(message) + + message.reply.assert_not_called() + + @pytest.mark.asyncio + async def test_on_message_from_bot_returns_early(self, bot): + """Test that bot messages are ignored.""" + cog = auto_verify.MessageHandler(bot) + + channel = helpers.MockTextChannel(id=settings.channels.UNVERIFIED_BOT_COMMANDS) + author = helpers.MockMember(bot=True) + message = helpers.MockMessage(channel=channel, author=author) + message.reply = mock.AsyncMock() + + await cog.on_message(message) + + message.reply.assert_not_called() diff --git a/tests/src/helpers/test_ban.py b/tests/src/helpers/test_ban.py index 1771928..e39fe26 100644 --- a/tests/src/helpers/test_ban.py +++ b/tests/src/helpers/test_ban.py @@ -4,7 +4,6 @@ import pytest from discord import Forbidden, HTTPException -from datetime import datetime, timezone from src.helpers.ban import _check_member, _dm_banned_member, ban_member from src.helpers.responses import SimpleResponse @@ -12,12 +11,11 @@ class TestBanHelpers: - @pytest.mark.asyncio async def test__check_member_staff_member(self, bot, guild, member): author = helpers.MockMember(name="Author User") member_is_staff = mock.Mock(return_value=True) - with mock.patch('src.helpers.ban.member_is_staff', member_is_staff): + with mock.patch("src.helpers.ban.member_is_staff", member_is_staff): response = await _check_member(bot, guild, member, author) assert isinstance(response, SimpleResponse) assert response.message == "You cannot ban another staff member." @@ -27,7 +25,7 @@ async def test__check_member_staff_member(self, bot, guild, member): async def test__check_member_regular_member(self, bot, guild, member): author = helpers.MockMember(name="Author User") member_is_staff = mock.Mock(return_value=False) - with mock.patch('src.helpers.ban.member_is_staff', member_is_staff): + with mock.patch("src.helpers.ban.member_is_staff", member_is_staff): response = await _check_member(bot, guild, member, author) assert response is None @@ -37,7 +35,7 @@ async def test__check_member_user(self, bot, guild, user): bot.get_member_or_user = AsyncMock() bot.get_member_or_user.return_value = user response = await _check_member(bot, guild, user, author) - assert await bot.get_member_or_user.called_once_with(guild, user.id) + bot.get_member_or_user.assert_called_once_with(guild, user.id) assert response is None @pytest.mark.asyncio @@ -110,7 +108,6 @@ def __init__(self, status, reason): class TestBanMember: - @pytest.mark.asyncio async def test_ban_member_valid_duration(self, bot, guild, member, author): duration = "1d" @@ -134,8 +131,9 @@ async def test_ban_member_valid_duration(self, bot, guild, member, author): result = await ban_member(bot, guild, member, duration, reason, evidence) assert isinstance(result, SimpleResponse) - assert result.message == f"{member.display_name} ({member.id}) has been banned until {expected_date} " \ - f"(UTC)." + assert ( + result.message == f"{member.display_name} ({member.id}) has been banned until {expected_date} (UTC)." + ) @pytest.mark.asyncio async def test_ban_member_invalid_duration(self, bot, guild, member, author): @@ -196,7 +194,7 @@ async def test_ban_member_no_reason_success(self, bot, guild, member, author): @pytest.mark.asyncio async def test_ban_member_no_author_success(self, bot, guild, member): - duration = '500w' + duration = "500w" reason = "" evidence = "Some evidence" member.display_name = "Banned Member" @@ -216,7 +214,7 @@ async def test_ban_member_no_author_success(self, bot, guild, member): @pytest.mark.asyncio async def test_ban_already_exists(self, bot, guild, member, author): - duration = '500w' + duration = "500w" reason = "" evidence = "Some evidence" member.display_name = "Banned Member" @@ -238,7 +236,7 @@ async def test_ban_already_exists(self, bot, guild, member, author): async def test_ban_member_staff(self, ctx, bot, guild): ctx.user = helpers.MockMember(id=1, name="Test User") user = helpers.MockMember(id=2, name="Banned User") - with patch('src.helpers.ban.member_is_staff', return_value=True): + with patch("src.helpers.ban.member_is_staff", return_value=True): response = await ban_member( bot, guild, user, "1d", "spamming", "some evidence", author=ctx.user, needs_approval=True ) @@ -250,7 +248,7 @@ async def test_ban_member_staff(self, ctx, bot, guild): async def test_ban_member_bot(self, ctx, bot, guild): ctx.user = helpers.MockMember(id=1, name="Test User") member = helpers.MockMember(id=2, name="Bot Member", bot=True) - with patch('src.helpers.ban.member_is_staff', return_value=False): + with patch("src.helpers.ban.member_is_staff", return_value=False): response = await ban_member( bot, guild, member, "1d", "spamming", "some evidence", author=ctx.user, needs_approval=True ) @@ -261,7 +259,7 @@ async def test_ban_member_bot(self, ctx, bot, guild): @pytest.mark.asyncio async def test_ban_self(self, ctx, bot, guild): ctx.user = helpers.MockMember(id=1, name="Test User") - with patch('src.helpers.ban.member_is_staff', return_value=False): + with patch("src.helpers.ban.member_is_staff", return_value=False): response = await ban_member( bot, guild, ctx.user, "1d", "spamming", "some evidence", author=ctx.user, needs_approval=True )