diff --git a/discord_bot/__main__.py b/discord_bot/__main__.py index 4264e13..49509e8 100644 --- a/discord_bot/__main__.py +++ b/discord_bot/__main__.py @@ -116,18 +116,12 @@ async def on_message(message: discord.Message): if len(args) > 1 and message.author.display_name is not None: left_wmark_text = f"By {message.author.display_name}" logging.info(args[0]) - meme_id = re.sub(r'[^A-Za-z0-9 _]', "", args[0]).strip() - args[0] = meme_id - img = meme_otron.compute(*args, left_wmark_text=left_wmark_text) - if img is None: - if len(meme_id) == 0: - response = f":warning: Template not found\n" - else: - hint = meme_db.find_nearest(meme_id) - response = f":warning: Template `{meme_id}` not found\n" - if hint is not None: - response += f"Did you mean `{hint}`?\n" - response += f"You can find a more detailed help and a list of templates at:\n" \ + img, errors = meme_otron.compute(*args, left_wmark_text=left_wmark_text) + if len(errors) > 0: + response = ":warning:" + for err in errors: + response += "\n" + err.replace("'", "`").replace("`` ", "") + response += f"\nYou can find a more detailed help and a list of templates at:\n" \ f"<{DOC_URL}>" if len(response) >= 2000: await message.channel.send(f"{message.author.mention} ... really?") @@ -138,7 +132,7 @@ async def on_message(message: discord.Message): img.save(output, format="JPEG") response = None if len(args) == 1: - meme = meme_db.get_meme(meme_id) + meme = meme_db.get_meme(utils.sanitize_input(args[0])) response = f"Template `{meme.id}`:" if len(meme.aliases) > 0: response += f"\n- Aliases: `{'`, `'.join(meme.aliases)}`" diff --git a/meme_otron/__main__.py b/meme_otron/__main__.py index 29120e4..1fcdd25 100644 --- a/meme_otron/__main__.py +++ b/meme_otron/__main__.py @@ -20,11 +20,10 @@ if __name__ == "__main__": sys.exit(1) else: output_file = utils.read_argument(sys.argv, "-o", "--output", valued=True, delete=True) - img = meme_otron.compute(*sys.argv[1:]) + img, errors = meme_otron.compute(*sys.argv[1:]) + for err in errors: + print(err, file=sys.stderr) if img is None: - proposal = meme_db.find_nearest(sys.argv[1]) - if proposal is not None: - print(f"Did you mean '{proposal}'?", file=sys.stderr) sys.exit(1) if output_file is None: with os.fdopen(os.dup(sys.stdout.fileno())) as output: diff --git a/meme_otron/meme_otron.py b/meme_otron/meme_otron.py index 1291cd3..8abae79 100644 --- a/meme_otron/meme_otron.py +++ b/meme_otron/meme_otron.py @@ -1,6 +1,5 @@ -import logging -from typing import Optional - +from typing import Optional, Tuple, List +import re from PIL import Image from .types import Text, Pos @@ -8,8 +7,6 @@ from . import img_factory from . import meme_db from . import utils -logger = logging.getLogger("meme_otron") - right_wmark = Text("Made with meme-otron") right_wmark.position = Pos.SE right_wmark.fill = (128, 128, 128, 128) @@ -32,19 +29,23 @@ simple_text.x_range = [0.01, 0.99] simple_text.y_range = [0.2, 0.8] -def compute(*args: str, left_wmark_text: Optional[str] = None, debug: bool = False) -> Optional[Image.Image]: +def compute(*args: str, left_wmark_text: Optional[str] = None, + debug: bool = False) -> Tuple[Optional[Image.Image], List[str]]: if len(args) < 1: - return None + return None, ['Not enough arguments'] parts = utils.split_arguments(args, "-") images = [] + errors = [] for part in parts: - img = compute_part(*part, debug=debug) + img, err = compute_part(*part, debug=debug) if img is not None: images += [img] + else: + errors += [err] if len(images) == 0: - return None + return None, errors output_image = img_factory.compose_image(images) @@ -53,24 +54,27 @@ def compute(*args: str, left_wmark_text: Optional[str] = None, debug: bool = Fal watermarks += [left_wmark.variant(left_wmark_text)] output_image = img_factory.apply_texts(output_image, watermarks, debug=debug) - return output_image + return output_image, errors -def compute_part(*args: str, debug: bool = False) -> Optional[Image.Image]: - meme_id = args[0].lower().strip() +def compute_part(*args: str, debug: bool = False) -> Tuple[Optional[Image.Image], Optional[str]]: + meme_id = utils.sanitize_input(args[0]) if meme_id == "text": if len(args) < 2: - return None + return None, 'Text: not enough arguments' texts = [simple_text.variant(arg) for arg in args[1:]] - return img_factory.build_text_only(texts, debug=debug) + return img_factory.build_text_only(texts, debug=debug), None elif meme_id == "image": - return None + return None, 'Image: not yet implemented' else: meme = meme_db.get_meme(meme_id) if meme is None: - logger.warning(f"Meme template '{meme_id}' not found") - return None + error = f"Template: '{meme_id}' not found." + proposal = meme_db.find_nearest(meme_id) + if proposal is not None: + error += f" Did you mean '{proposal}'?" + return None, error if len(args) > 1: c = 0 for i in range(len(meme.texts)): @@ -82,4 +86,4 @@ def compute_part(*args: str, debug: bool = False) -> Optional[Image.Image]: c += 1 else: meme.texts[i].text = meme.texts[meme.texts[i].text_ref].text - return img_factory.build_from_template(meme.template, meme.texts, debug=debug) + return img_factory.build_from_template(meme.template, meme.texts, debug=debug), None diff --git a/meme_otron/utils.py b/meme_otron/utils.py index 2d133d9..310b75c 100644 --- a/meme_otron/utils.py +++ b/meme_otron/utils.py @@ -142,6 +142,10 @@ def find_nearest(word: str, wlist: List[str], threshold: int = 5) -> Optional[st return found[2] +def sanitize_input(src: str) -> str: + return re.sub(r'[^A-Za-z0-9 _]', "", src.lower().strip()) + + # endregion # region format utils diff --git a/tests/unit/meme_otron/test_utils.py b/tests/unit/meme_otron/test_utils.py index 27c695a..eed3961 100644 --- a/tests/unit/meme_otron/test_utils.py +++ b/tests/unit/meme_otron/test_utils.py @@ -93,6 +93,9 @@ class TestUtilsLang(TestCase): self.assertIsNone(utils.find_nearest("unknown", ["test", "example", "what"], threshold=2)) self.assertEqual("test", utils.find_nearest("unknown", ["test", "example", "what"], threshold=200)) + def test_sanitize_input(self): + self.assertEqual("", utils.sanitize_input("")) + self.assertEqual("a b_c", utils.sanitize_input(" A+=¤$ bé_cè:* ")) class TestUtilsArgs(TestCase): def test_parse_arguments(self):