From 59a8530cbe3c2ad9dc1c7c9237c9e73cc240002e Mon Sep 17 00:00:00 2001 From: klemek Date: Thu, 30 Apr 2020 08:55:12 +0200 Subject: [PATCH] checking image file size beforehand --- discord_bot/__main__.py | 46 +++++++++++++++++++--------------------- meme_otron/meme_otron.py | 13 ++++++++++-- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/discord_bot/__main__.py b/discord_bot/__main__.py index c208027..572aa6a 100644 --- a/discord_bot/__main__.py +++ b/discord_bot/__main__.py @@ -121,7 +121,8 @@ async def on_message(message: discord.Message): if len(message.attachments) > 0: input_data = await message.attachments[0].read() - img, errors = meme_otron.compute(*args, left_wmark_text=left_wmark_text, input_data=input_data) + img, errors = meme_otron.compute(*args, left_wmark_text=left_wmark_text, + input_data=input_data, max_file_size=8 * 1024 * 1024) if len(errors) > 0: response = ":warning:" for err in errors: @@ -135,29 +136,26 @@ async def on_message(message: discord.Message): else: with tempfile.NamedTemporaryFile(delete=False) as output: img.save(output, format="JPEG") - if os.stat(output.name).st_size > 8 * 1024 * 1024: # 8MB - await message.channel.send(":warning:\nOutput image is too big to be sent by discord") - else: - response = None - meme_id = utils.sanitize_input(args[0]) - if len(args) == 1 and meme_id not in ["image", "text"]: - meme = meme_db.get_meme(meme_id) - response = f"Template `{meme.id}`:" - if len(meme.aliases) > 0: - response += f"\n- Aliases: `{'`, `'.join(meme.aliases)}`" - if meme.info is not None: - response += f"\n- More info: <{meme.info}>" - response += f"\n- Use:" \ - f"\n```{meme.id} \"" + \ - "\" \"".join([f"text {i + 1}" for i in range(meme.texts_len)]) + \ - "\"```" - elif not is_direct: - response = f"A meme by {message.author.mention}:" - if message_id not in SENT: - SENT[message_id] = [] - response = await message.channel.send(response, - file=discord.File(filename="meme.jpg", fp=output.name)) - SENT[message_id] += [response] + response = None + meme_id = utils.sanitize_input(args[0]) + if len(args) == 1 and meme_id not in ["image", "text"]: + meme = meme_db.get_meme(meme_id) + response = f"Template `{meme.id}`:" + if len(meme.aliases) > 0: + response += f"\n- Aliases: `{'`, `'.join(meme.aliases)}`" + if meme.info is not None: + response += f"\n- More info: <{meme.info}>" + response += f"\n- Use:" \ + f"\n```{meme.id} \"" + \ + "\" \"".join([f"text {i + 1}" for i in range(meme.texts_len)]) + \ + "\"```" + elif not is_direct: + response = f"A meme by {message.author.mention}:" + if message_id not in SENT: + SENT[message_id] = [] + response = await message.channel.send(response, + file=discord.File(filename="meme.jpg", fp=output.name)) + SENT[message_id] += [response] try: os.remove(output.name) except PermissionError: diff --git a/meme_otron/meme_otron.py b/meme_otron/meme_otron.py index 17549a9..41f893c 100644 --- a/meme_otron/meme_otron.py +++ b/meme_otron/meme_otron.py @@ -1,6 +1,7 @@ from typing import Optional, Tuple, List import re from PIL import Image +from io import BytesIO from .types import Text, Pos from . import img_factory @@ -31,6 +32,7 @@ simple_text.y_range = [0.2, 0.8] def compute(*args: str, input_data: Optional[bytes] = None, wmark: bool = True, left_wmark_text: Optional[str] = None, + max_file_size: Optional[int] = None, debug: bool = False) -> Tuple[Optional[Image.Image], List[str]]: if len(args) < 1: return None, ['Not enough arguments'] @@ -39,7 +41,7 @@ def compute(*args: str, input_data: Optional[bytes] = None, images = [] errors = [] for part in parts: - img, err = compute_part(*part, input_data=input_data, debug=debug) + img, err = compute_part(*part, input_data=input_data, max_file_size=max_file_size, debug=debug) if img is not None: images += [img] else: @@ -56,10 +58,17 @@ def compute(*args: str, input_data: Optional[bytes] = None, watermarks += [left_wmark.variant(left_wmark_text)] output_image = img_factory.apply_texts(output_image, watermarks, debug=debug) + if max_file_size is not None: + img_file = BytesIO() + output_image.save(img_file, 'jpg') + if img_file.tell() > max_file_size: + return None, ['Output image too big'] + return output_image, errors def compute_part(*args: str, input_data: Optional[bytes] = None, + max_file_size: Optional[int] = None, debug: bool = False) -> Tuple[Optional[Image.Image], Optional[str]]: meme_id = utils.sanitize_input(args[0]) @@ -73,7 +82,7 @@ def compute_part(*args: str, input_data: Optional[bytes] = None, if len(args) <= 1: return None, 'Image: received no input data nor URL' else: - input_data, err = utils.read_web(args[1]) + input_data, err = utils.read_web(args[1], max_file_size=max_file_size) if input_data is None: return None, 'Image: ' + err img = img_factory.build_image_only(input_data)