206 lines
7.5 KiB
Python
206 lines
7.5 KiB
Python
import os
|
|
import html
|
|
import time
|
|
import logging
|
|
import asyncio
|
|
import traceback
|
|
import functools
|
|
import mimetypes
|
|
import yaml
|
|
import aiohttp
|
|
from collections import deque
|
|
from datetime import timedelta
|
|
from pyrogram import Client, StopPropagation, ContinuePropagation
|
|
from pyrogram.types import Chat, User
|
|
from pyrogram.parser import parser
|
|
from pyrogram.errors.exceptions.bad_request_400 import PeerIdInvalid, ChannelInvalid
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
with open('config.yaml') as config:
|
|
config = yaml.safe_load(config)
|
|
loop = asyncio.get_event_loop()
|
|
help_dict = dict()
|
|
log_ring = deque(maxlen=config['config'].get('log_ring_maxlen', 69420))
|
|
|
|
apps = []
|
|
app_user_ids = dict()
|
|
# this code here exists because i can't be fucked
|
|
class Parser(parser.Parser):
|
|
async def parse(self, text, mode):
|
|
if mode == 'through':
|
|
return text
|
|
return await super().parse(text, mode)
|
|
for session_name in config['config']['sessions']:
|
|
app = Client(session_name, api_id=config['telegram']['api_id'], api_hash=config['telegram']['api_hash'], plugins={'root': os.path.join(__package__, 'plugins')}, parse_mode='html', workdir='sessions', test_mode=config['telegram'].get('use_test_servers', False))
|
|
app.parser = Parser(app)
|
|
apps.append(app)
|
|
slave = Client('sukuinote-slave', api_id=config['telegram']['api_id'], api_hash=config['telegram']['api_hash'], plugins={'root': os.path.join(__package__, 'slave-plugins')}, parse_mode='html', bot_token=config['telegram']['slave_bot_token'], workdir='sessions', test_mode=config['telegram'].get('use_test_servers', False))
|
|
slave.parser = Parser(slave)
|
|
session = aiohttp.ClientSession()
|
|
|
|
async def get_entity(client, entity):
|
|
entity_client = client
|
|
if not isinstance(entity, Chat):
|
|
try:
|
|
entity = int(entity)
|
|
except ValueError:
|
|
pass
|
|
except TypeError:
|
|
entity = entity.id
|
|
try:
|
|
entity = await client.get_chat(entity)
|
|
except (PeerIdInvalid, ChannelInvalid):
|
|
for app in apps:
|
|
if app != client:
|
|
try:
|
|
entity = await app.get_chat(entity)
|
|
except (PeerIdInvalid, ChannelInvalid):
|
|
pass
|
|
else:
|
|
entity_client = app
|
|
break
|
|
else:
|
|
entity = await slave.get_chat(entity)
|
|
entity_client = slave
|
|
return entity, entity_client
|
|
|
|
async def get_user(client, entity):
|
|
entity_client = client
|
|
if not isinstance(entity, User):
|
|
try:
|
|
entity = int(entity)
|
|
except ValueError:
|
|
pass
|
|
except TypeError:
|
|
entity = entity.id
|
|
try:
|
|
entity = await client.get_users(entity)
|
|
except PeerIdInvalid:
|
|
for app in apps:
|
|
if app != client:
|
|
try:
|
|
entity = await app.get_users(entity)
|
|
except PeerIdInvalid:
|
|
pass
|
|
else:
|
|
entity_client = app
|
|
break
|
|
else:
|
|
entity = await slave.get_users(entity)
|
|
entity_client = slave
|
|
return entity, entity_client
|
|
|
|
def log_errors(func):
|
|
@functools.wraps(func)
|
|
async def wrapper(client, *args):
|
|
try:
|
|
await func(client, *args)
|
|
except (StopPropagation, ContinuePropagation):
|
|
raise
|
|
except BaseException:
|
|
tb = traceback.format_exc()
|
|
try:
|
|
await slave.send_message(config['config']['log_chat'], f'Exception occured in {func.__name__}\n\n{tb}', parse_mode=None)
|
|
except BaseException:
|
|
logging.exception('Failed to log exception for %s as slave', func.__name__)
|
|
tb = traceback.format_exc()
|
|
for app in apps:
|
|
try:
|
|
await app.send_message(config['config']['log_chat'], f'Exception occured in {func.__name__}\n\n{tb}', parse_mode=None)
|
|
except BaseException:
|
|
logging.exception('Failed to log exception for %s as app', func.__name__)
|
|
tb = traceback.format_exc()
|
|
else:
|
|
break
|
|
raise
|
|
raise
|
|
return wrapper
|
|
|
|
def public_log_errors(func):
|
|
@functools.wraps(func)
|
|
async def wrapper(client, message):
|
|
try:
|
|
await func(client, message)
|
|
except (StopPropagation, ContinuePropagation):
|
|
raise
|
|
except BaseException:
|
|
await message.reply_text(traceback.format_exc(), parse_mode=None)
|
|
raise
|
|
return wrapper
|
|
|
|
# https://stackoverflow.com/a/49361727
|
|
def format_bytes(size):
|
|
size = int(size)
|
|
# 2**10 = 1024
|
|
power = 1024
|
|
n = 0
|
|
power_labels = {0 : '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}
|
|
while size > power:
|
|
size /= power
|
|
n += 1
|
|
return f"{size:.2f} {power_labels[n]+'B'}"
|
|
|
|
# https://stackoverflow.com/a/34325723
|
|
def return_progress_string(current, total):
|
|
filled_length = int(30 * current // total)
|
|
return '[' + '=' * filled_length + ' ' * (30 - filled_length) + ']'
|
|
|
|
# https://stackoverflow.com/a/852718
|
|
# https://stackoverflow.com/a/775095
|
|
def calculate_eta(current, total, start_time):
|
|
if not current:
|
|
return '00:00:00'
|
|
end_time = time.time()
|
|
elapsed_time = end_time - start_time
|
|
seconds = (elapsed_time * (total / current)) - elapsed_time
|
|
thing = ''.join(str(timedelta(seconds=seconds)).split('.')[:-1]).split(', ')
|
|
thing[-1] = thing[-1].rjust(8, '0')
|
|
return ', '.join(thing)
|
|
|
|
progress_callback_data = dict()
|
|
async def progress_callback(current, total, reply, text, upload):
|
|
message_identifier = (reply.chat.id, reply.message_id)
|
|
last_edit_time, prevtext, start_time = progress_callback_data.get(message_identifier, (0, None, time.time()))
|
|
if current == total:
|
|
try:
|
|
progress_callback_data.pop(message_identifier)
|
|
except KeyError:
|
|
pass
|
|
elif (time.time() - last_edit_time) > 1:
|
|
handle = 'Upload' if upload else 'Download'
|
|
if last_edit_time:
|
|
speed = format_bytes((total - current) / (time.time() - start_time))
|
|
else:
|
|
speed = '0 B'
|
|
text = f'''{text}
|
|
<code>{return_progress_string(current, total)}</code>
|
|
|
|
<b>Total Size:</b> {format_bytes(total)}
|
|
<b>{handle}ed Size:</b> {format_bytes(current)}
|
|
<b>{handle} Speed:</b> {speed}/s
|
|
<b>ETA:</b> {calculate_eta(current, total, start_time)}'''
|
|
if prevtext != text:
|
|
await reply.edit_text(text)
|
|
prevtext = text
|
|
last_edit_time = time.time()
|
|
progress_callback_data[message_identifier] = last_edit_time, prevtext, start_time
|
|
|
|
async def get_file_mimetype(filename):
|
|
mimetype = mimetypes.guess_type(filename)[0]
|
|
if not mimetype:
|
|
proc = await asyncio.create_subprocess_exec('file', '--brief', '--mime-type', filename, stdout=asyncio.subprocess.PIPE)
|
|
stdout, _ = await proc.communicate()
|
|
mimetype = stdout.decode().strip()
|
|
return mimetype or ''
|
|
|
|
async def get_file_ext(filename):
|
|
proc = await asyncio.create_subprocess_exec('file', '--brief', '--extension', filename, stdout=asyncio.subprocess.PIPE)
|
|
stdout, _ = await proc.communicate()
|
|
ext = stdout.decode().strip().split('/', maxsplit=1)[0]
|
|
if not ext or ext == '???':
|
|
mimetype = await get_file_mimetype(filename)
|
|
ext = mimetypes.guess_extension(mimetype) or '.bin'
|
|
if not ext.startswith('.'):
|
|
ext = '.' + ext
|
|
return ext
|