"""
|
||
Middleware для проверки сообщений на запрещённые слова (банворды).
|
||
"""
|
||
|
||
from typing import Callable, Dict, Any, Awaitable, Optional
import re
import unicodedata

from aiogram import BaseMiddleware
from aiogram.types import Message
from aiogram.exceptions import TelegramBadRequest

from configs import settings, UNICODE_MAP, LATIN_TO_CYRILLIC, CYRILLIC_NORMALIZE
from database import get_manager, BanWordType
from bot.special import extract_words, get_lemma
from middleware.loggers import logger


__all__ = ("BanWordsMiddleware",)

URL_PATTERN = re.compile(
    r'(https?://\S+|www\.\S+)',
    re.IGNORECASE
)
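# e.g. URL_PATTERN.findall("see www.example.com and https://t.me/x")
#   -> ['www.example.com', 'https://t.me/x']
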
class TextNormalizer:
    """
    Multi-stage text normalizer.
    """

    # Combined character map built from the configured LATIN_TO_CYRILLIC,
    # CYRILLIC_NORMALIZE and UNICODE_MAP tables.
    FULL_MAP: Dict[str, str] = {}
    FULL_MAP.update(LATIN_TO_CYRILLIC)
    FULL_MAP.update(CYRILLIC_NORMALIZE)
    FULL_MAP.update(UNICODE_MAP)

    SEPARATORS = re.compile(r'[\s.\-_,;:|]+', re.UNICODE)
    # 3+ consecutive identical letters (Cyrillic or Latin)
    REPEAT_PATTERN = re.compile(r'([а-яёa-z])\1{2,}', re.IGNORECASE)

    @classmethod
    def normalize_characters(cls, text: str) -> str:
        # map each character through FULL_MAP, then lowercase
        result: list[str] = []
        for ch in text:
            result.append(cls.FULL_MAP.get(ch, ch))
        return ''.join(result).lower()

    @classmethod
    def remove_separators(cls, text: str) -> str:
        return cls.SEPARATORS.sub('', text)

    @classmethod
    def collapse_repeats(cls, text: str) -> str:
        # e.g. "спаааам" -> "спам"
        def repl(match: re.Match[str]) -> str:
            return match.group(1)
        return cls.REPEAT_PATTERN.sub(repl, text)

    @classmethod
    def normalize_full(
        cls,
        text: str,
        remove_sep: bool = True,
        collapse: bool = True
    ) -> str:
        text = unicodedata.normalize('NFKC', text)
        text = cls.normalize_characters(text)

        if remove_sep:
            text = cls.remove_separators(text)

        if collapse:
            text = cls.collapse_repeats(text)

        return text

    @classmethod
    def normalize_for_part_token(cls, text: str) -> str:
        """
        Normalization for PART rules:
        - NFKC
        - lower()
        - strip zero-width characters
        - collapse repeated Latin letters
        - WITHOUT applying LATIN_TO_CYRILLIC
        """
        text = unicodedata.normalize('NFKC', text)
        text = text.lower()

        # strip zero-width characters
        text = re.sub(r'[\u200B-\u200D\uFEFF]', '', text)

        # collapse repeated Latin letters (2+ -> 1)
        text = re.sub(r'([a-z])\1+', r'\1', text)

        return text

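# Pipeline sketch (actual output depends on the configured maps; this
# assumes LATIN_TO_CYRILLIC maps the Latin "a" to the Cyrillic "а"):
#   TextNormalizer.normalize_full("с п a м")          -> "спам"
#   TextNormalizer.normalize_for_part_token("spaaam") -> "spam"
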
class BanWordsMiddleware(BaseMiddleware):
    """Deletes messages that contain banned words."""

    def __init__(self) -> None:
        super().__init__()
        self.manager = get_manager()
        self.normalizer = TextNormalizer()

    async def __call__(
        self,
        handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]],
        event: Message,
        data: Dict[str, Any]
    ) -> Any:

        if not event.text and not event.caption:
            return await handler(event, data)

        message_text: str = event.text or event.caption

        # commands are not filtered
        if message_text.startswith('/'):
            return await handler(event, data)

        user_id: int = event.from_user.id
        is_super_admin: bool = user_id in settings.OWNER_ID
        is_admin: bool = is_super_admin or self.manager.is_admin_cached(user_id)

        # admins bypass the filter
        if is_admin:
            return await handler(event, data)

        spam_result = await self._check_message(message_text)

        if spam_result:
            await self._handle_spam(event)
            return None

        return await handler(event, data)

    @staticmethod
    def is_allowed_url(url: str, allowed: str) -> bool:
        url_lower = url.lower()
        allowed_lower = allowed.lower()
        if allowed_lower.endswith('/'):
            # whitelist entry ends with a slash: require that exact prefix
            return url_lower.startswith(allowed_lower)
        else:
            # no trailing slash: allow an exact match, or the entry followed
            # by '/' (never a bare string prefix)
            return url_lower == allowed_lower or url_lower.startswith(allowed_lower + '/')

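    # Examples of the rule above; the slash requirement keeps look-alike
    # domains from slipping through:
    #   is_allowed_url("https://example.com/page", "https://example.com")
    #       -> True
    #   is_allowed_url("https://example.com.evil.io", "https://example.com")
    #       -> False
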
    async def _check_message(self, text: str) -> Optional[Dict[str, str]]:
        whitelist = {
            w.lower().strip()
            for w in self.manager.get_whitelist_cached()
        }

        # ================= URL CHECK =================
        urls = URL_PATTERN.findall(text)

        for url in urls:
            url_lower = url.lower()

            # skip the URL if it starts with a whitelisted prefix
            if any(self.is_allowed_url(url_lower, allowed) for allowed in whitelist):
                continue

            # otherwise run the WORD rules against the URL itself
            for word in self.manager.get_banwords_cached(BanWordType.WORD):
                if word in url_lower:
                    return {"word": word, "type": "word"}
        # =============================================

        # strip URLs from the text before the word/lemma checks below
        text_without_urls = URL_PATTERN.sub(' ', text)

        word_words = self.manager.get_banwords_cached(BanWordType.WORD)
        lemma_words = self.manager.get_banwords_cached(BanWordType.LEMMA)
        part_words = self.manager.get_banwords_cached(BanWordType.PART)
        conflict_word = self.manager.get_banwords_cached(BanWordType.CONFLICT_WORD)
        conflict_lemma = self.manager.get_banwords_cached(BanWordType.CONFLICT_LEMMA)
        conflict_part = self.manager.get_banwords_cached(BanWordType.CONFLICT_PART)

        # silence mode: every non-admin message counts as spam
        if await self.manager.is_silence_active():
            return {"word": "[silence mode]", "type": "silence"}

        # conflict mode: only the CONFLICT_WORD / CONFLICT_LEMMA rules apply
        if await self.manager.is_conflict_active():
            normalized_text = self.normalizer.normalize_full(text)

            for word in conflict_word:
                if self.normalizer.normalize_full(word) in normalized_text:
                    return {"word": word, "type": "conflict_word"}

            for word_text in extract_words(text):
                if get_lemma(word_text) in conflict_lemma:
                    return {"word": word_text, "type": "conflict_lemma"}

            return None

        # WORD: strict match as a standalone word; explicit lookarounds are
        # used instead of \b so banwords may start or end with non-word chars
        for word in word_words:
            pattern = r'(?<!\w){}(?!\w)'.format(re.escape(word))

            for match in re.finditer(pattern, text_without_urls, re.IGNORECASE):
                matched = match.group(0).lower()

                # ignore the hit if the matched word is whitelisted
                if matched in whitelist:
                    continue

                # skip if the match is actually the scheme of a URL that
                # URL_PATTERN did not catch (indices must refer to the same
                # string the pattern was matched against)
                if text_without_urls[match.end():match.end() + 3] == '://':
                    continue

                return {"word": word, "type": "word"}

        # PART: collect @usernames and Latin-script tokens
        usernames = re.findall(r'@[\w_]+', text_without_urls)
        latin_tokens = re.findall(r'\b[a-zA-Z0-9_]*[a-zA-Z]+[a-zA-Z0-9_]*\b', text_without_urls)

        tokens_to_check = usernames + latin_tokens

        # PART: substring match inside each token
        for token in tokens_to_check:
            token_lower = token.lower()

            # skip if this exact token is whitelisted
            normalized_for_whitelist = token_lower.lstrip('@')

            if (
                token_lower in whitelist or
                normalized_for_whitelist in whitelist or
                f"@{normalized_for_whitelist}" in whitelist
            ):
                continue

            normalized_token = self.normalizer.normalize_for_part_token(token)

            for part in part_words:
                norm_part = self.normalizer.normalize_for_part_token(part)

                if norm_part in normalized_token:
                    return {"word": part, "type": "part"}

        # CONFLICT PART: same token check against the conflict list
        for token in tokens_to_check:
            token_lower = token.lower()

            normalized_for_whitelist = token_lower.lstrip('@')

            if (
                token_lower in whitelist or
                normalized_for_whitelist in whitelist or
                f"@{normalized_for_whitelist}" in whitelist
            ):
                continue

            normalized_token = self.normalizer.normalize_for_part_token(token)

            for part in conflict_part:
                norm_part = self.normalizer.normalize_for_part_token(part)

                if norm_part in normalized_token:
                    return {"word": part, "type": "conflict_part"}

        # LEMMA: normalize each word and compare its lemma
        for word_text in extract_words(text_without_urls):
            word_lower = word_text.lower()

            # skip whitelisted words
            if word_lower in whitelist:
                continue

            normalized_word = self.normalizer.normalize_full(word_text)
            lemma = get_lemma(normalized_word)

            if lemma in lemma_words:
                return {"word": lemma, "type": "lemma"}

        return None

    @staticmethod
    async def _handle_spam(
        message: Message,
    ) -> None:

        try:
            await message.delete()
            logger.info(f"Deleted message: {message.text or message.caption}")
        except TelegramBadRequest:
            # e.g. the message was already deleted
            return
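

# Registration sketch (assumes an aiogram 3.x Dispatcher; the import path
# is hypothetical and should match where this module actually lives):
#
#     from aiogram import Dispatcher
#     from middleware.banwords import BanWordsMiddleware
#
#     dp = Dispatcher()
#     dp.message.middleware(BanWordsMiddleware())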