# redditbot/redditbot.py
#
# Python source; listing metadata from the original viewer: 278 lines, 12 KiB.

import asyncio
import functools
import html
import json
import logging
import mimetypes
import os
import random
import shutil
import tempfile
import time
import traceback
from itertools import zip_longest
from urllib.parse import urlparse, urlunparse

import aiocron
import aiohttp
import praw
import yaml
from telethon import TelegramClient, events
from telethon.utils import chunks
# Load settings from config.yaml in the working directory.  The file is
# decoded explicitly as UTF-8 rather than with the locale's default encoding,
# so non-ASCII values (e.g. chat names) load the same on every host.
with open('config.yaml', encoding='utf-8') as file:
    config_data = yaml.safe_load(file)

# Required keys are indexed directly (a missing one raises KeyError at
# start-up); optional keys use .get() and default to None.
tg_api_id = config_data['telegram']['api_id']
tg_api_hash = config_data['telegram']['api_hash']
bot_token = config_data['telegram'].get('bot_token')
reddit_client_id = config_data['reddit']['client_id']
reddit_client_secret = config_data['reddit']['client_secret']
# Optional: chat and message id whose attached file holds the seen-posts list.
storage_chat = config_data['config'].get('storage_chat')
storage_msg_id = config_data['config'].get('storage_message_id')
send_to_chats = config_data['config']['send_to_chats']
subreddits = config_data['config']['subreddits']
bot_admins = config_data['config']['bot_admins']
# Cron expression handed to aiocron.crontab for the posting schedule.
cron_duration = config_data['config']['cron_duration']

logging.basicConfig(level=logging.INFO)
async def main():
client = await TelegramClient('redditbot', tg_api_id, tg_api_hash).start(bot_token=bot_token)
client.parse_mode = 'html'
session = aiohttp.ClientSession()
reddit = praw.Reddit(client_id=reddit_client_id, client_secret=reddit_client_secret, user_agent='linux:redditbot:v1.0.0 (by /u/the_blank_x)')
try:
if storage_chat and storage_msg_id:
await (await client.get_messages(storage_chat, ids=storage_msg_id)).download_media('redditbot.json')
with open('redditbot.json') as file:
seen_posts = json.load(file)
except Exception:
traceback.print_exc()
seen_posts = []
async def write_seen_posts():
with open('redditbot.json', 'w') as file:
json.dump(seen_posts, file)
if storage_chat and storage_msg_id:
await client.edit_message(storage_chat, storage_msg_id, file='redditbot.json')
@aiocron.crontab(cron_duration)
async def start_post():
while True:
random.shuffle(subreddits)
to_break = False
for subreddit_name in subreddits:
subreddit = reddit.subreddit(subreddit_name)
while True:
random_post = subreddit.random()
cpid = None
if random_post is None:
for submission in subreddit.hot():
cpid = getattr(submission, 'crosspost_parent', None)
if cpid:
cpid = cpid[3:]
if submission.id in seen_posts or cpid in seen_posts:
continue
random_post = submission
break
cpid = getattr(random_post, 'crosspost_parent', None)
if cpid:
cpid = cpid[3:]
if random_post.id in seen_posts or cpid in seen_posts:
continue
seen_posts.append(cpid or random_post.id)
print(random_post.id, random_post.shortlink)
to_break = True
break
if to_break:
break
to_break = False
for _ in range(5):
try:
await _actual_start_post(random_post, send_to_chats)
except Exception:
traceback.print_exc()
for i in bot_admins:
await client.send_message(i, f'{random_post.id}\n{traceback.format_exc()}', parse_mode=None)
else:
to_break = True
break
if to_break:
break
await write_seen_posts()
async def _start_broadcast(text, file, chats):
uploaded_files = []
for i in file or []:
uploaded_files.append(await client.upload_file(i))
for chat in chats:
for i in chunks(zip_longest(text, uploaded_files), 10):
j, k = zip(*i)
if not any(k):
k = None
if not k and len(j) == 1:
j = j[0]
await client.send_message(chat, j, file=k)
async def _download_file(filename, url):
print(url)
async with session.get(url) as resp:
with open(filename, 'wb') as file:
while True:
chunk = await resp.content.read(10)
if not chunk:
break
file.write(chunk)
async def _get_file_ext(filename):
proc = await asyncio.create_subprocess_exec('file', '--brief', '--extension', filename, stdout=asyncio.subprocess.PIPE)
stdout, _ = await proc.communicate()
ext = stdout.decode().strip().split('/', maxsplit=1)[0]
if not ext or ext == '???':
mimetype = mimetypes.guess_type(filename)[0]
if not mimetype:
proc = await asyncio.create_subprocess_exec('file', '--brief', '--mime-type', filename, stdout=asyncio.subprocess.PIPE)
stdout, _ = await proc.communicate()
mimetype = stdout.decode().strip()
ext = mimetypes.guess_extension(mimetype or '') or '.bin'
if not ext.startswith('.'):
ext = '.' + ext
return ext
async def _actual_start_post(random_post, chats):
text = f'<a href="{random_post.shortlink}">{html.escape(random_post.title)}</a>'
cpid = getattr(random_post, 'crosspost_parent', None)
if cpid:
random_post = reddit.submission(cpid[3:])
text += f' (crosspost of <a href="{random_post.shortlink}">{html.escape(random_post.title)}</a>)'
if not random_post.is_self:
with tempfile.TemporaryDirectory() as tempdir:
url = random_post.url
filename = os.path.join(tempdir, str(time.time()))
files = [filename]
captions = [text]
if random_post.is_video:
ffmpeg_exists = any(True for i in os.environ.get('PATH', '').split(':') if os.path.exists(os.path.join(i, 'ffmpeg')))
reddit_video = random_post.secure_media['reddit_video']
for i in ('hls_url', 'dash_url'):
if not ffmpeg_exists:
continue
url = reddit_video.get(i)
if not url:
continue
print(url)
proc = await asyncio.create_subprocess_exec('ffmpeg', '-nostdin', '-y', '-i', url, '-c', 'copy', '-f', 'mp4', filename)
await proc.communicate()
if not proc.returncode:
url = None
break
else:
url = reddit_video['fallback_url']
elif getattr(random_post, 'is_gallery', None):
files = []
captions = []
for a, i in enumerate(random_post.media_metadata):
captions.append(f'{text}\n#{a + 1}')
filename = os.path.join(tempdir, str(time.time()))
i = random_post.media_metadata[i]
await _download_file(filename, i['s']['u'])
files.append(filename)
url = None
if url:
parsed = list(urlparse(url))
splitted = os.path.splitext(parsed[2])
domain = getattr(random_post, 'domain', parsed[1])
preview = getattr(random_post, 'preview', None)
if domain.endswith('imgur.com'):
parsed[1] = 'i.imgur.com'
if parsed[2].startswith('/a/'):
albumid = os.path.split(parsed[2])[1]
async with session.get(f'https://imgur.com/ajaxalbums/getimages/{albumid}/hit.json?all=true') as resp:
apidata = (await resp.json())['data']
if apidata['count'] == 1:
parsed[2] = apidata['images'][0]['hash'] + apidata['images'][0]['ext']
desc = apidata['images'][0]['description']
if desc:
captions[0] += '\n' + html.escape(desc)
else:
files = []
captions = []
for a, i in enumerate(apidata['images']):
to_append = f'{text}\n#{a + 1}'
desc = i['description']
if desc:
to_append += ': ' + html.escape(desc)
captions.append(to_append)
filename = os.path.join(tempdir, str(time.time()))
await _download_file(filename, f'https://i.imgur.com/{i["hash"]}{i["ext"]}')
files.append(filename)
url = None
if splitted[1] == '.gifv':
parsed[2] = splitted[0] + '.mp4'
if url:
url = urlunparse(parsed)
elif domain == 'gfycat.com':
async with session.get(f'https://api.gfycat.com/v1/gfycats/{parsed[2]}') as resp:
apidata = await resp.json()
gfyitem = apidata.get('gfyItem')
if gfyitem:
url = gfyitem.get('mp4Url', url)
elif random_post.is_reddit_media_domain and splitted[1] == '.gif' and preview:
preview = preview['images'][0]['variants']
for i in ('mp4', 'gif'):
if i in preview:
url = preview[i]['source']['url']
break
elif random_post.is_reddit_media_domain and preview and preview['enabled'] and not random_post.is_video:
url = preview['images'][0]['source']['url']
if url:
await _download_file(filename, url)
for a, i in enumerate(files):
ext = await _get_file_ext(i)
os.rename(i, i + ext)
files[a] = i + ext
await _start_broadcast(captions, files, chats)
else:
await _start_broadcast([text], None, chats)
def register(pattern):
def wrapper(func):
@functools.wraps(func)
@client.on(events.NewMessage(chats=bot_admins, pattern=pattern))
async def awrapper(e):
try:
await func(e)
except Exception:
await e.reply(traceback.format_exc(), parse_mode=None)
raise
return awrapper
return wrapper
@register('/(start|help)')
async def start_or_help(e):
await e.reply(('/start - /help\n'
'/help - /start\n'
'/poweroff - shuts down bot\n'
'/test <submission id> - tests sending submission'), parse_mode=None)
@register('/poweroff')
async def poweroff(e):
await e.reply('ok')
await e.client.disconnect()
@register('/test (.+)')
async def test_post(e):
await e.reply('ok')
post = reddit.submission(e.pattern_match.group(1))
await _actual_start_post(post, [e.chat_id])
# await start_post.func()
try:
await client.run_until_disconnected()
finally:
await session.close()
if __name__ == "__main__":
    # Script entry point: run the bot's event loop until the Telegram client
    # disconnects (main() awaits run_until_disconnected()).
    asyncio.run(main())