You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
redditbot/redditbot.py

353 lines
16 KiB

import os
import time
import html
import json
import random
import shutil
import logging
import asyncio
import tempfile
import functools
import mimetypes
import traceback
from itertools import zip_longest
from urllib.parse import urljoin, urlparse, urlunparse

import yaml
import praw
import aiohttp
import aiocron
from bs4 import BeautifulSoup
from telethon import TelegramClient, events
from telethon.utils import chunks
# Load every setting once, at import time, from config.yaml in the working directory.
with open('config.yaml') as file:
    config_data = yaml.safe_load(file)

# Telegram credentials; bot_token is optional (None when absent).
_telegram_section = config_data['telegram']
tg_api_id = _telegram_section['api_id']
tg_api_hash = _telegram_section['api_hash']
bot_token = _telegram_section.get('bot_token')

# Reddit API credentials.
_reddit_section = config_data['reddit']
reddit_client_id = _reddit_section['client_id']
reddit_client_secret = _reddit_section['client_secret']

# Bot behaviour settings.
_bot_section = config_data['config']
storage_chat = _bot_section.get('storage_chat')
storage_msg_id = _bot_section.get('storage_message_id')
_bkup_subreddits = _bot_section.get('subreddits')
_send_to_chats = _bot_section['send_to_chats']

# Normalise send_to_chats into {chat: subreddit list or None}.  An entry is
# either a bare chat or a mapping {chat: subreddits}; only the first pair of
# a mapping entry is used.  None means "no per-chat list was given".
send_to_chats = {}
for _entry in _send_to_chats:
    if isinstance(_entry, dict):
        _target, _subs = list(_entry.items())[0]
    else:
        _target, _subs = _entry, None
    send_to_chats[_target] = _subs

bot_admins = _bot_section['bot_admins']
cron_duration = _bot_section['cron_duration']

logging.basicConfig(level=logging.INFO)
# Upper bound on subreddit.random() draws per subreddit before falling back to
# the hot listing; without a bound, a subreddit whose random posts are all
# already seen would be polled forever.
_MAX_RANDOM_DRAWS = 20
# How many different posts start_post tries per chat and tick before giving up,
# so a persistent send failure cannot loop forever.
_MAX_SEND_ATTEMPTS = 5


def _truncate(text, limit):
    """Return *text* cut down to at most *limit* characters.

    When *text* is longer than *limit*, it is cut to ``limit - 1`` characters
    and an ellipsis ('…') is appended, so the result is exactly *limit* long.
    A non-positive *limit* yields an empty string.
    """
    if limit <= 0:
        return ''
    if len(text) <= limit:
        return text
    return text[:limit - 1] + '…'


def _crosspost_parent_id(submission):
    """Return the id of the post *submission* crossposts, or None.

    The first three characters of ``crosspost_parent`` are dropped
    (presumably the ``t3_`` fullname prefix — matches the original slicing).
    """
    cpid = getattr(submission, 'crosspost_parent', None)
    return cpid[3:] if cpid else None


def _is_seen(submission, seen):
    """Return True if *submission* or the post it crossposts is in *seen*."""
    return submission.id in seen or _crosspost_parent_id(submission) in seen


def _find_meta(soup, attr, value):
    """Return the first <meta> tag whose *attr* equals *value* and has content."""
    return soup.find(
        lambda tag: tag.name == 'meta'
        and tag.attrs.get(attr) == value
        and tag.attrs.get('content')
    )


async def main():
    """Run the bot until the Telegram client disconnects.

    Connects to Telegram and Reddit using the module-level configuration, and
    restores the seen-posts state from redditbot.json. That file is first
    downloaded from ``storage_msg_id`` in ``storage_chat`` when both are set.
    It then schedules ``start_post`` with ``cron_duration`` and serves the
    /start, /help, /poweroff and /test commands from ``bot_admins``.

    NOTE(review): PRAW calls (and lazy attribute access on PRAW objects) are
    synchronous, so they block the event loop while they run.
    """
    zws = '\u200b'
    client = await TelegramClient('redditbot', tg_api_id, tg_api_hash).start(bot_token=bot_token)
    client.parse_mode = 'html'
    session = aiohttp.ClientSession()
    reddit = praw.Reddit(client_id=reddit_client_id, client_secret=reddit_client_secret, user_agent='linux:redditbot:v1.0.0 (by /u/the_blank_x)')
    try:
        if storage_chat and storage_msg_id:
            await (await client.get_messages(storage_chat, ids=storage_msg_id)).download_media('redditbot.json')
        with open('redditbot.json') as file:
            seen_posts = json.load(file)
        if isinstance(seen_posts, list):
            # Upgrade the legacy format (a flat list of ids) to the per-chat layout.
            seen_posts = {'version': 0, 'chats': {'global': seen_posts}}
    except Exception:
        logging.exception('Loading JSON')
        seen_posts = {'version': 0, 'chats': {'global': []}}
    # seen_posts['chats'] maps a chat key (str peer id, or 'global' for chats
    # using the fallback subreddit list) to a list of submission ids.

    async def write_seen_posts():
        """Write seen_posts to redditbot.json and, if configured, re-upload it."""
        with open('redditbot.json', 'w') as file:
            json.dump(seen_posts, file)
        if storage_chat and storage_msg_id:
            await client.edit_message(storage_chat, storage_msg_id, file='redditbot.json')

    def _pick_unseen_post(subreddit_names, seen):
        """Return an unseen submission from the first subreddit that has one.

        Each subreddit is sampled with ``subreddit.random()`` up to
        _MAX_RANDOM_DRAWS times. If that returns None or only seen posts, the
        first unseen entry of its hot listing is used. Returns None when no
        subreddit yields anything new.
        """
        for subreddit_name in subreddit_names:
            subreddit = reddit.subreddit(subreddit_name)
            for _ in range(_MAX_RANDOM_DRAWS):
                post = subreddit.random()
                if post is None:
                    break
                if not _is_seen(post, seen):
                    return post
            post = next((s for s in subreddit.hot() if not _is_seen(s, seen)), None)
            if post is not None:
                return post
        return None

    @aiocron.crontab(cron_duration)
    async def start_post():
        """Send one unseen submission to every configured chat, then save state."""
        global_sp = seen_posts['chats']['global']
        for chat in send_to_chats:
            subreddits = send_to_chats[chat]
            uses_bkupsub = not subreddits
            if uses_bkupsub:
                subreddits = _bkup_subreddits
            if not subreddits:
                logging.error('No subreddits configured for %r; skipping', chat)
                continue
            chat = await client.get_peer_id(chat)
            own_sp = seen_posts['chats'].setdefault(str(chat), [])
            # Chats on the fallback list share the 'global' history.
            chat_sp = global_sp if uses_bkupsub else own_sp
            for _ in range(_MAX_SEND_ATTEMPTS):
                # Rebuilt every attempt because a failed attempt still marks its post seen.
                seen = set(chat_sp)
                seen.update(global_sp)
                # Shuffle a copy so the configured list is left untouched.
                candidates = random.sample(list(subreddits), len(subreddits))
                post = _pick_unseen_post(candidates, seen)
                if post is None:
                    logging.warning('No unseen posts found for chat %s', chat)
                    break
                chat_sp.append(_crosspost_parent_id(post) or post.id)
                logging.info('%s %s', post.id, post.shortlink)
                try:
                    await _actual_start_post(post, [chat])
                except Exception:
                    logging.exception(post.id)
                    tb = traceback.format_exc()
                    for admin in bot_admins:
                        await client.send_message(admin, f'{post.id}\n{tb}', parse_mode=None)
                else:
                    break
            await write_seen_posts()

    async def _start_broadcast(text, file, chats):
        """Upload the *file* paths once, then send them with captions *text* to each chat.

        *text* is a list of captions and *file* an optional list of paths. They
        are paired positionally and sent in groups of 10. A group with no files
        and a single caption is sent as a plain message.
        """
        uploaded_files = []
        for path in file or []:
            uploaded_files.append(await client.upload_file(path))
        for chat in chats:
            for group in chunks(zip_longest(text, uploaded_files), 10):
                captions, media = zip(*group)
                if not any(media):
                    media = None
                if not media and len(captions) == 1:
                    captions = captions[0]
                await client.send_message(chat, captions, file=media, link_preview=False)

    async def _download_file(filename, url):
        """Stream *url* into *filename* through the shared aiohttp session."""
        logging.info('Downloading %s', url)
        async with session.get(url) as resp:
            with open(filename, 'wb') as file:
                async for chunk in resp.content.iter_chunked(65536):
                    file.write(chunk)

    async def _get_file_mimetype(filename):
        """Guess *filename*'s MIME type from its name, else via `file`; '' if unknown."""
        mimetype = mimetypes.guess_type(filename)[0]
        if not mimetype:
            proc = await asyncio.create_subprocess_exec('file', '--brief', '--mime-type', filename, stdout=asyncio.subprocess.PIPE)
            stdout, _ = await proc.communicate()
            mimetype = stdout.decode().strip()
        return mimetype or ''

    async def _get_file_ext(filename):
        """Return a dot-prefixed extension for *filename*'s content ('.bin' if unknown)."""
        proc = await asyncio.create_subprocess_exec('file', '--brief', '--extension', filename, stdout=asyncio.subprocess.PIPE)
        stdout, _ = await proc.communicate()
        # `file` may list alternatives separated by '/'; take the first.
        ext = stdout.decode().strip().split('/', maxsplit=1)[0]
        if not ext or ext == '???':
            mimetype = await _get_file_mimetype(filename)
            ext = mimetypes.guess_extension(mimetype) or '.bin'
        if not ext.startswith('.'):
            ext = '.' + ext
        return ext

    async def _actual_start_post(random_post, chats):
        """Fetch the content of *random_post* and broadcast it to *chats*.

        Crossposts are resolved to their parent. Self posts are sent as text.
        Videos, galleries, imgur/gfycat links, reddit gifs, plain files and
        HTML pages (sent as a link preview) are downloaded into a temporary
        directory and uploaded.

        NOTE(review): the 2047/4095 caption budgets are carried over from the
        original code; confirm them against Telegram's current limits.
        """
        text = f'<a href="{random_post.shortlink}">{html.escape(random_post.title)}</a>'
        cpid = getattr(random_post, 'crosspost_parent', None)
        if cpid and getattr(random_post, 'crosspost_parent_list', None):
            random_post = reddit.submission(cpid[3:])
            text += f' (crosspost of <a href="{random_post.shortlink}">{html.escape(random_post.title)}</a>)'
        if random_post.is_self:
            if getattr(random_post, 'selftext', None):
                # 4094 leaves room for the '\n\n' separator added below.
                caplength = 4094 - len(client.parse_mode.parse(text)[0])
                text += '\n\n' + html.escape(_truncate(random_post.selftext.strip(), caplength))
            await _start_broadcast([text], None, chats)
            return
        with tempfile.TemporaryDirectory() as tempdir:
            url = random_post.url
            filename = os.path.join(tempdir, str(time.time()))
            files = [filename]
            captions = [text]
            if random_post.is_video:
                reddit_video = random_post.secure_media['reddit_video']
                remuxed = False
                if shutil.which('ffmpeg'):
                    # Prefer a stream that ffmpeg can remux into one mp4 file.
                    for key in ('hls_url', 'dash_url'):
                        stream_url = reddit_video.get(key)
                        if not stream_url:
                            continue
                        logging.info('Remuxing %s', stream_url)
                        proc = await asyncio.create_subprocess_exec('ffmpeg', '-nostdin', '-y', '-i', stream_url, '-c', 'copy', '-f', 'mp4', filename)
                        await proc.communicate()
                        if not proc.returncode:
                            remuxed = True
                            break
                # url None means the file is already in place.
                url = None if remuxed else reddit_video['fallback_url']
            elif getattr(random_post, 'is_gallery', None):
                files = []
                captions = []
                for a, media_id in enumerate(random_post.media_metadata):
                    captions.append(f'{text}\n#{a + 1}')
                    filename = os.path.join(tempdir, str(time.time()))
                    await _download_file(filename, random_post.media_metadata[media_id]['s']['u'])
                    files.append(filename)
                url = None
            if url:
                parsed = list(urlparse(url))
                splitted = os.path.splitext(parsed[2])
                domain = getattr(random_post, 'domain', parsed[1])
                preview = getattr(random_post, 'preview', None)
                if domain.endswith('imgur.com'):
                    parsed[1] = 'i.imgur.com'
                    if parsed[2].startswith('/a/'):
                        albumid = os.path.split(parsed[2])[1]
                        async with session.get(f'https://imgur.com/ajaxalbums/getimages/{albumid}/hit.json?all=true') as resp:
                            apidata = (await resp.json())['data']
                        if apidata['count'] == 1:
                            image = apidata['images'][0]
                            parsed[2] = image['hash'] + image['ext']
                            desc = image['description']
                            if desc:
                                caplength = 2047 - len(client.parse_mode.parse(captions[0])[0])
                                captions[0] += '\n' + html.escape(_truncate(desc, caplength))
                        else:
                            files = []
                            captions = []
                            caplength = 2047 - len(client.parse_mode.parse(text)[0])
                            for a, image in enumerate(apidata['images']):
                                to_append = f'#{a + 1}'
                                desc = image['description']
                                if desc:
                                    to_append += ': ' + desc.strip()
                                captions.append(text + '\n' + html.escape(_truncate(to_append, caplength)))
                                filename = os.path.join(tempdir, str(time.time()))
                                await _download_file(filename, f'https://i.imgur.com/{image["hash"]}{image["ext"]}')
                                files.append(filename)
                            url = None
                    if splitted[1] == '.gifv':
                        parsed[2] = splitted[0] + '.mp4'
                    if url:
                        url = urlunparse(parsed)
                elif domain == 'gfycat.com':
                    async with session.get(f'https://api.gfycat.com/v1/gfycats/{parsed[2]}') as resp:
                        apidata = await resp.json()
                    gfyitem = apidata.get('gfyItem')
                    if gfyitem:
                        url = gfyitem.get('mp4Url', url)
                elif random_post.is_reddit_media_domain and splitted[1] == '.gif' and preview:
                    # Separate name so `preview` keeps its full structure for the size check below.
                    variants = preview['images'][0]['variants']
                    for kind in ('mp4', 'gif'):
                        if kind in variants:
                            url = variants[kind]['source']['url']
                            break
                if url:
                    await _download_file(filename, url)
                    mimetype = await _get_file_mimetype(filename)
                    if mimetype.startswith('image') and preview and preview.get('enabled'):
                        preview_image = preview['images'][0]
                        # Source first, then the listed resolutions in reverse order
                        # (presumably largest to smallest).
                        preview_urls = [res['url'] for res in preview_image['resolutions']]
                        preview_urls.append(preview_image['source']['url'])
                        preview_urls.reverse()
                        for preview_url in preview_urls:
                            if os.path.getsize(filename) < 10000000:
                                break
                            await _download_file(filename, preview_url)
                    ext = await _get_file_ext(filename)
                    if ext.startswith('.htm'):
                        # Let BeautifulSoup detect the page encoding from the raw bytes.
                        with open(filename, 'rb') as file:
                            soup = BeautifulSoup(file.read(), 'html.parser')
                        ptitle = _find_meta(soup, 'property', 'og:title') or soup.find('title')
                        if ptitle:
                            ptitle = ptitle.attrs.get('content', ptitle.text).strip()
                        pdesc = _find_meta(soup, 'property', 'og:description') or _find_meta(soup, 'name', 'description')
                        if pdesc:
                            pdesc = pdesc.attrs.get('content', pdesc.text).strip()
                        pimg = _find_meta(soup, 'property', 'og:image')
                        if pimg:
                            pimg = pimg.attrs.get('content', '').strip()
                        escaped_url = html.escape(url)
                        tat = f'{text}\n\nURL: '
                        if ptitle:
                            tat += f'<a href="{escaped_url}">{html.escape(ptitle)}</a>'
                        else:
                            tat += escaped_url
                        files = []
                        if pimg:
                            # og:image may be relative to the page URL.
                            await _download_file(filename, urljoin(url, pimg))
                            files.append(filename)
                        else:
                            # Zero-width link to the page, prepended to the caption.
                            tat = f'<a href="{escaped_url}">{zws}</a>{tat}'
                        if pdesc:
                            caplength = (2047 if pimg else 4095) - len(client.parse_mode.parse(tat)[0])
                            tat += '\n' + html.escape(_truncate(pdesc, caplength))
                        captions = [tat]
            # Give every file a content-based extension before uploading.
            for a, path in enumerate(files):
                ext = await _get_file_ext(path)
                os.rename(path, path + ext)
                files[a] = path + ext
            await _start_broadcast(captions, files, chats)

    def register(pattern):
        """Register the decorated handler for admin messages matching *pattern*.

        Any exception is replied back to the sender as a traceback and then re-raised.
        """
        def wrapper(func):
            @functools.wraps(func)
            @client.on(events.NewMessage(chats=bot_admins, pattern=pattern))
            async def awrapper(e):
                try:
                    await func(e)
                except Exception:
                    await e.reply(traceback.format_exc(), parse_mode=None)
                    raise
            return awrapper
        return wrapper

    @register('/(start|help)')
    async def start_or_help(e):
        """Reply with the command list."""
        await e.reply(('/start - /help\n'
                       '/help - /start\n'
                       '/poweroff - shuts down bot\n'
                       '/test <submission id> - tests sending submission'), parse_mode=None)

    @register('/poweroff')
    async def poweroff(e):
        """Acknowledge, then disconnect the client (which ends main)."""
        await e.reply('ok')
        await e.client.disconnect()

    @register('/test (.+)')
    async def test_post(e):
        """Send the given submission id to the requesting chat."""
        await e.reply('ok')
        post = reddit.submission(e.pattern_match.group(1))
        await _actual_start_post(post, [e.chat_id])

    try:
        await client.run_until_disconnected()
    finally:
        await session.close()


if __name__ == '__main__':
    asyncio.run(main())