diff --git a/telegram/ext/filters.py b/telegram/ext/filters.py index 1942fa8593d..28b15ee070e 100644 --- a/telegram/ext/filters.py +++ b/telegram/ext/filters.py @@ -324,6 +324,79 @@ def filter(self, message): group = _Group() + class user(BaseFilter): + """Filters messages to allow only those which are from specified user ID. + + Notes: + Only one of chat_id or username must be used here. + + Args: + user_id(Optional[int|list]): which user ID(s) to allow through. + username(Optional[str|list]): which username(s) to allow through. If username starts + with '@' symbol, it will be ignored. + + Raises: + ValueError + """ + + def __init__(self, user_id=None, username=None): + if not (bool(user_id) ^ bool(username)): + raise ValueError('One and only one of user_id or username must be used') + if user_id is not None and isinstance(user_id, int): + self.user_ids = [user_id] + else: + self.user_ids = user_id + if username is None: + self.usernames = username + elif isinstance(username, str_type): + self.usernames = [username.replace('@', '')] + else: + self.usernames = [user.replace('@', '') for user in username] + + def filter(self, message): + if self.user_ids is not None: + return bool(message.from_user and message.from_user.id in self.user_ids) + else: + # self.usernames is not None + return bool(message.from_user and message.from_user.username and + message.from_user.username in self.usernames) + + class chat(BaseFilter): + """Filters messages to allow only those which are from specified chat ID. + + Notes: + Only one of chat_id or username must be used here. + + Args: + chat_id(Optional[int|list]): which chat ID(s) to allow through. + username(Optional[str|list]): which username(s) to allow through. If username starts + with '@' symbol, it will be ignored. + + Raises: + ValueError + """ + + def __init__(self, chat_id=None, username=None): + if not (bool(chat_id) ^ bool(username)): + raise ValueError('One and only one of chat_id or username must be used') + if chat_id is not None and isinstance(chat_id, int): + self.chat_ids = [chat_id] + else: + self.chat_ids = chat_id + if username is None: + self.usernames = username + elif isinstance(username, str_type): + self.usernames = [username.replace('@', '')] + else: + self.usernames = [chat.replace('@', '') for chat in username] + + def filter(self, message): + if self.chat_ids is not None: + return bool(message.chat_id in self.chat_ids) + else: + # self.usernames is not None + return bool(message.chat.username and message.chat.username in self.usernames) + class _Invoice(BaseFilter): def filter(self, message): diff --git a/tests/test_filters.py b/tests/test_filters.py index e8c4c6637d2..2fcded4bc9b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -213,6 +213,51 @@ def test_group_fileter(self): self.message.chat.type = "supergroup" self.assertTrue(Filters.group(self.message)) + def test_filters_chat(self): + with self.assertRaisesRegexp(ValueError, 'chat_id or username'): + Filters.chat(chat_id=-1, username='chat') + with self.assertRaisesRegexp(ValueError, 'chat_id or username'): + Filters.chat() + + def test_filters_chat_id(self): + self.assertFalse(Filters.chat(chat_id=-1)(self.message)) + self.message.chat.id = -1 + self.assertTrue(Filters.chat(chat_id=-1)(self.message)) + self.message.chat.id = -2 + self.assertTrue(Filters.chat(chat_id=[-1, -2])(self.message)) + self.assertFalse(Filters.chat(chat_id=-1)(self.message)) + + def test_filters_chat_username(self): + self.assertFalse(Filters.chat(username='chat')(self.message)) + self.message.chat.username = 'chat' + self.assertTrue(Filters.chat(username='@chat')(self.message)) + self.assertTrue(Filters.chat(username='chat')(self.message)) + self.assertTrue(Filters.chat(username=['chat1', 'chat', 'chat2'])(self.message)) + self.assertFalse(Filters.chat(username=['@chat1', 'chat_2'])(self.message)) + + def test_filters_user(self): + with self.assertRaisesRegexp(ValueError, 'user_id or username'): + Filters.user(user_id=1, username='user') + with self.assertRaisesRegexp(ValueError, 'user_id or username'): + Filters.user() + + def test_filters_user_id(self): + self.assertFalse(Filters.user(user_id=1)(self.message)) + self.message.from_user.id = 1 + self.assertTrue(Filters.user(user_id=1)(self.message)) + self.message.from_user.id = 2 + self.assertTrue(Filters.user(user_id=[1, 2])(self.message)) + self.assertFalse(Filters.user(user_id=1)(self.message)) + + def test_filters_username(self): + self.assertFalse(Filters.user(username='user')(self.message)) + self.assertFalse(Filters.user(username='Testuser')(self.message)) + self.message.from_user.username = 'user' + self.assertTrue(Filters.user(username='@user')(self.message)) + self.assertTrue(Filters.user(username='user')(self.message)) + self.assertTrue(Filters.user(username=['user1', 'user', 'user2'])(self.message)) + self.assertFalse(Filters.user(username=['@username', '@user_2'])(self.message)) + def test_and_filters(self): self.message.text = 'test' self.message.forward_date = True
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: