import os
import html
import time
import logging
import asyncio
import traceback
import functools
import mimetypes
import yaml
import aiohttp
from collections import deque
from datetime import timedelta
from pyrogram import Client, StopPropagation, ContinuePropagation
from pyrogram.types import Chat, User
from pyrogram.parser import parser
from pyrogram.errors.exceptions.bad_request_400 import PeerIdInvalid, ChannelInvalid

logging.basicConfig(level=logging.INFO)

# Parse config.yaml once at import time; the resulting mapping is exposed as
# the module-level ``config``.  The file handle gets its own name so it is not
# shadowed by the parsed result.
with open('config.yaml') as config_file:
    config = yaml.safe_load(config_file)

loop = asyncio.get_event_loop()
help_dict = dict()
# Bounded in-memory log buffer; size comes from config (default 69420 entries).
log_ring = deque(maxlen=config['config'].get('log_ring_maxlen', 69420))
apps = []
app_user_ids = dict()


class Parser(parser.Parser):
    """Pyrogram parser with an extra ``'through'`` parse mode.

    In ``'through'`` mode the text is returned unchanged instead of being
    handed to Pyrogram's parser; every other mode is delegated to the base
    class exactly as given.
    """

    async def parse(self, text, mode):
        if mode == 'through':
            return text
        return await super().parse(text, mode)


# One user client per configured session, each loading the 'plugins' package
# and using the custom Parser above.
for session_name in config['config']['sessions']:
    app = Client(
        session_name,
        api_id=config['telegram']['api_id'],
        api_hash=config['telegram']['api_hash'],
        plugins={'root': os.path.join(__package__, 'plugins')},
        parse_mode='html',
        workdir='sessions',
        test_mode=config['telegram'].get('use_test_servers', False),
    )
    app.parser = Parser(app)
    apps.append(app)

# Bot-token client that loads the 'slave-plugins' package; also used as the
# last-resort resolver and logger below.
slave = Client(
    'sukuinote-slave',
    api_id=config['telegram']['api_id'],
    api_hash=config['telegram']['api_hash'],
    plugins={'root': os.path.join(__package__, 'slave-plugins')},
    parse_mode='html',
    bot_token=config['telegram']['slave_bot_token'],
    workdir='sessions',
    test_mode=config['telegram'].get('use_test_servers', False),
)
slave.parser = Parser(slave)

session = aiohttp.ClientSession()


async def get_entity(client, entity):
    """Resolve *entity* to a Chat, falling back across all known clients.

    Args:
        client: the client to try first.
        entity: a Chat (returned as-is), a numeric id / numeric string, a
            username-like string, or any object with an ``.id`` attribute.

    Returns:
        A ``(chat, entity_client)`` tuple where ``entity_client`` is the client
        that successfully resolved the chat.

    Raises:
        Whatever ``slave.get_chat`` raises if no client can resolve it.
    """
    entity_client = client
    if not isinstance(entity, Chat):
        try:
            entity = int(entity)
        except ValueError:
            # Not numeric: keep the string (e.g. a username) as-is.
            pass
        except TypeError:
            # Not int-convertible at all: assume an object carrying an id.
            entity = entity.id
        try:
            entity = await client.get_chat(entity)
        except (PeerIdInvalid, ChannelInvalid):
            # The first client cannot see this peer; try the other apps.
            for app in apps:
                if app != client:
                    try:
                        entity = await app.get_chat(entity)
                    except (PeerIdInvalid, ChannelInvalid):
                        pass
                    else:
                        entity_client = app
                        break
            else:
                # No app resolved it: the bot client is the last resort.
                entity = await slave.get_chat(entity)
                entity_client = slave
    return entity, entity_client


async def get_user(client, entity):
    """Resolve *entity* to a User, falling back across all known clients.

    Same strategy as ``get_entity`` but uses ``get_users`` and only treats
    ``PeerIdInvalid`` as "try the next client".

    Returns:
        A ``(user, entity_client)`` tuple.
    """
    entity_client = client
    if not isinstance(entity, User):
        try:
            entity = int(entity)
        except ValueError:
            pass
        except TypeError:
            entity = entity.id
        try:
            entity = await client.get_users(entity)
        except PeerIdInvalid:
            for app in apps:
                if app != client:
                    try:
                        entity = await app.get_users(entity)
                    except PeerIdInvalid:
                        pass
                    else:
                        entity_client = app
                        break
            else:
                entity = await slave.get_users(entity)
                entity_client = slave
    return entity, entity_client


def log_errors(func):
    """Decorate a handler so uncaught exceptions are reported to the log chat.

    The traceback is sent via the ``slave`` client.  If that fails, each user
    app is tried in turn, with the traceback refreshed to include the failure.
    The original exception (or the slave failure) is always re-raised.
    Pyrogram's propagation-control exceptions pass through untouched.
    """
    @functools.wraps(func)
    async def wrapper(client, *args):
        try:
            await func(client, *args)
        except (StopPropagation, ContinuePropagation):
            raise
        except BaseException:
            tb = traceback.format_exc()
            try:
                await slave.send_message(config['config']['log_chat'], f'Exception occured in {func.__name__}\n\n{tb}', parse_mode=None)
            except BaseException:
                logging.exception('Failed to log exception for %s as slave', func.__name__)
                # Re-capture so the message also shows why the slave failed.
                tb = traceback.format_exc()
                for app in apps:
                    try:
                        await app.send_message(config['config']['log_chat'], f'Exception occured in {func.__name__}\n\n{tb}', parse_mode=None)
                    except BaseException:
                        logging.exception('Failed to log exception for %s as app', func.__name__)
                        tb = traceback.format_exc()
                    else:
                        break
                raise
            raise
    return wrapper


def public_log_errors(func):
    """Decorate a handler so uncaught exceptions are replied into the chat.

    The traceback is sent as a plain-text reply to the triggering message and
    the exception is then re-raised.
    """
    @functools.wraps(func)
    async def wrapper(client, message):
        try:
            await func(client, message)
        except (StopPropagation, ContinuePropagation):
            raise
        except BaseException:
            await message.reply_text(traceback.format_exc(), parse_mode=None)
            raise
    return wrapper


# https://stackoverflow.com/a/49361727
def format_bytes(size):
    """Format a byte count with binary (1024-based) units, e.g. '1.50 KB'.

    Values of exactly one unit roll over to that unit ('1.00 KB' for 1024).
    Sizes beyond the largest label stay in that label instead of raising.
    """
    size = int(size)
    power = 1024  # 2**10
    n = 0
    power_labels = ('', 'K', 'M', 'G', 'T', 'P', 'E')
    # '>=' so that exactly 1024 becomes 1.00 KB; cap n to avoid an IndexError.
    while size >= power and n < len(power_labels) - 1:
        size /= power
        n += 1
    return f"{size:.2f} {power_labels[n]+'B'}"


# https://stackoverflow.com/a/34325723
def return_progress_string(current, total):
    """Return a fixed-width 30-column text progress bar like '[=====     ]'.

    A non-positive *total* yields an empty bar.  The fill is clamped to
    [0, 30] so the bar width never changes.
    """
    width = 30
    if total <= 0:
        filled_length = 0
    else:
        filled_length = int(width * current // total)
        filled_length = max(0, min(width, filled_length))
    return '[' + '=' * filled_length + ' ' * (width - filled_length) + ']'


# https://stackoverflow.com/a/852718
# https://stackoverflow.com/a/775095
def calculate_eta(current, total, start_time):
    """Estimate remaining time as 'HH:MM:SS' (prefixed with 'N day(s), ').

    The estimate extrapolates the elapsed time since *start_time* linearly.
    It returns '00:00:00' when nothing has been transferred yet or when the
    estimate is negative (current beyond total).
    """
    if not current:
        return '00:00:00'
    elapsed_time = time.time() - start_time
    seconds = (elapsed_time * (total / current)) - elapsed_time
    # Truncate to whole seconds up front: str(timedelta) only has a '.' part
    # when microseconds are non-zero, so splitting on '.' is unreliable.
    parts = str(timedelta(seconds=int(max(seconds, 0)))).split(', ')
    # Zero-pad the clock part, e.g. '1:02:03' -> '01:02:03'.
    parts[-1] = parts[-1].rjust(8, '0')
    return ', '.join(parts)


# Per-message progress state: (chat_id, message_id) ->
# (last_edit_time, last_rendered_text, start_time).
progress_callback_data = dict()


async def progress_callback(current, total, reply, text, upload):
    """Render transfer progress into *reply*, editing it at most about once per second.

    Presumably used as a Pyrogram ``progress=`` callback with
    ``progress_args=(reply, text, upload)`` — confirm against callers.

    Args:
        current: bytes transferred so far.
        total: total bytes.
        reply: the message to edit with the progress text.
        text: header line shown above the progress details.
        upload: True for uploads, False for downloads (affects wording only).
    """
    message_identifier = (reply.chat.id, reply.message_id)
    last_edit_time, prevtext, start_time = progress_callback_data.get(message_identifier, (0, None, time.time()))
    if current == total:
        # Transfer finished: drop the state and stop editing.
        try:
            progress_callback_data.pop(message_identifier)
        except KeyError:
            pass
    elif (time.time() - last_edit_time) > 1:
        handle = 'Upload' if upload else 'Download'
        if last_edit_time:
            # Average throughput = bytes transferred so far / time elapsed.
            # last_edit_time being set guarantees more than 1s has elapsed.
            speed = format_bytes(current / (time.time() - start_time))
        else:
            speed = '0 B'
        text = f'''{text}
{return_progress_string(current, total)}
Total Size: {format_bytes(total)}
{handle}ed Size: {format_bytes(current)}
{handle} Speed: {speed}/s
ETA: {calculate_eta(current, total, start_time)}'''
        # Skip identical edits; Telegram rejects no-op message edits.
        if prevtext != text:
            await reply.edit_text(text)
            prevtext = text
            last_edit_time = time.time()
            progress_callback_data[message_identifier] = last_edit_time, prevtext, start_time


async def get_file_mimetype(filename):
    """Return the MIME type of *filename*, or '' if it cannot be determined.

    The file name is guessed first via ``mimetypes``; otherwise the external
    ``file --brief --mime-type`` command is asked.
    """
    mimetype = mimetypes.guess_type(filename)[0]
    if not mimetype:
        proc = await asyncio.create_subprocess_exec('file', '--brief', '--mime-type', filename, stdout=asyncio.subprocess.PIPE)
        stdout, _ = await proc.communicate()
        mimetype = stdout.decode().strip()
    return mimetype or ''


async def get_file_ext(filename):
    """Return a dot-prefixed extension for *filename*, defaulting to '.bin'.

    It first asks ``file --brief --extension`` and uses the first of any
    '/'-separated candidates.  If that yields nothing or '???', it falls back
    to deriving the extension from the file's MIME type.
    """
    proc = await asyncio.create_subprocess_exec('file', '--brief', '--extension', filename, stdout=asyncio.subprocess.PIPE)
    stdout, _ = await proc.communicate()
    ext = stdout.decode().strip().split('/', maxsplit=1)[0]
    if not ext or ext == '???':
        mimetype = await get_file_mimetype(filename)
        ext = mimetypes.guess_extension(mimetype) or '.bin'
    if not ext.startswith('.'):
        ext = '.' + ext
    return ext