From fa89302e54d0fb15bc47a26cd17de7891f1cdb0f Mon Sep 17 00:00:00 2001 From: jsl12 <32917998+jsl12@users.noreply.github.com> Date: Sun, 23 Jan 2022 13:27:35 -0600 Subject: [PATCH] handling reactions again --- kwaylon/data.py | 65 +++++++++++++++++++++++++--------------------- kwaylon/kwaylon.py | 49 ++++++++++++++++------------------ kwaylon/msg.py | 24 ++++++++++------- main.py | 19 +++++++------- 4 files changed, 81 insertions(+), 76 deletions(-) diff --git a/kwaylon/data.py b/kwaylon/data.py index fedb9a9..93ac4d7 100644 --- a/kwaylon/data.py +++ b/kwaylon/data.py @@ -8,11 +8,13 @@ import pandas as pd from nextcord import Client, Message from .msg import LOGGER, reaction_df +from .msg import reaction_dict LOGGER = logging.getLogger(__name__) class MsgData: + """Wrapper class to manage saving and loading the DataFrame of reactions""" db_path: Path msgs: pd.DataFrame reactions: pd.DataFrame @@ -22,6 +24,9 @@ class MsgData: self.lock = asyncio.Lock() self.db_path: Path = Path(path) if isinstance(path, str) else path + def __repr__(self): + return f'<{__name__}.{self.__class__.__name__} with {self.reactions.shape[0]} reactions>' + @property def sql_context(self): return sqlite3.connect(self.db_path) @@ -31,16 +36,28 @@ class MsgData: with self.sql_context as con: LOGGER.info(f'Opened {self.db_path.name}') try: - self.reactions = pd.read_sql('select * from reactions', con=con).reset_index() - self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime']).dt.tz_convert(local_tz) + self.reactions = pd.read_sql('select * from reactions', con=con).reset_index(drop=True) except: LOGGER.warning(f'failed to read reactions from: {self.db_path.resolve()}') else: LOGGER.info(f'read {self.reactions.shape[0]:,} reactions') + self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime']) + LOGGER.info(f"'datetime' dtype: {self.reactions['datetime'].dtype}") + LOGGER.info(f"{self.reactions['datetime'].values[:3]}...") + + # try: + # self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime']).dt.tz_convert(local_tz) + # except Exception as e: + # LOGGER.exception(e) + # try: + # self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime']).dt.tz_localize(local_tz) + # except Exception as e: + # LOGGER.exception(e) + # LOGGER.warning(f'Error converting timezone to {local_tz}') + con.close() - async def scan_messages(self, client: Client, **kwargs): - self.reactions = await reaction_df(client, **kwargs) + async def write_sql(self): async with self.lock: with self.sql_context as con: self.reactions.to_sql( @@ -48,9 +65,14 @@ class MsgData: con=con, if_exists='replace', index=False, - index_label=self.reactions.index.name + # index_label=self.reactions.index.name ) - LOGGER.info(f'wrote {self.reactions.shape[0]:,} into {self.db_path.name}') + LOGGER.info(f'wrote {self.reactions.shape[0]:,} reactions into {self.db_path.name}') + + async def scan_messages(self, client: Client, **kwargs): + async with self.lock: + self.reactions = await reaction_df(client, **kwargs) + await self.write_sql() def most(self, emoji: str): matching = self.reactions['emoji'] == emoji @@ -65,38 +87,21 @@ class MsgData: channel = await guild.fetch_channel(row['channel_id']) return await channel.fetch_message(row['msg_id']) - def __repr__(self): - return f'<{__name__}.{self.__class__.__name__} with {self.reactions.shape[0]} reactions>' - - # async def add_msg(self, message: Message): - # async with self.lock: - # mdict = message_dict(message) - # mdict.pop('id') - # self.msgs.loc[message.id] = pd.Series(mdict) - # LOGGER.info(f'Added message id {message.id} from {message.author}: {message.content}') - async def update_reaction(self, msg: Message): # Drop all the reactions for this message id, if there are any try: async with self.lock: - self.reactions.drop(msg.id, level=0, axis=0, inplace=True) + self.reactions = self.reactions.loc[self.reactions['msg_id'] != msg.id] except KeyError as e: pass # If there are reactions on the message after the change if len(msg.reactions) > 0: - new = reaction_df(msg) + new = pd.DataFrame([reaction_dict(r) for r in msg.reactions]) async with self.lock: self.reactions = self.reactions.append(new) - LOGGER.info(str(new.droplevel(level=0, axis=0).loc[:, 'count'])) - - # if msg.id not in self.msgs.index: - # await self.add_msg(msg) - - return new - - async def emoji_user_counts(self, client: Client, emoji_name: str, days: int = None): - """Creates a Series indexed by user display_name with the number of reactions with emoji_name as values""" - counts: pd.Series = self.emoji_totals(emoji_name, days) - counts.index = pd.Index([(await client.fetch_user(user_id=uid)).display_name for uid in counts.index]) - return counts + try: + await self.write_sql() + except: + LOGGER.info(self.reactions.columns) + LOGGER.info(self.reactions.dtypes) diff --git a/kwaylon/kwaylon.py b/kwaylon/kwaylon.py index 389bb70..445f5da 100644 --- a/kwaylon/kwaylon.py +++ b/kwaylon/kwaylon.py @@ -5,6 +5,7 @@ from pathlib import Path import nextcord as discord from nextcord import Client, Message, TextChannel +from nextcord import RawReactionActionEvent from . import jokes from .data import MsgData @@ -15,7 +16,7 @@ LOGGER = logging.getLogger(__name__) class Kwaylon(Client): - db_path: Path = Path('../messages.db') + db_path: Path = Path('./messages.db') def __init__(self, limit: int = 5000, days: int = 30, *args, **kwargs): super().__init__(*args, **kwargs) @@ -36,7 +37,11 @@ class Kwaylon(Client): # await alive() - self.data = MsgData(path=Path('./messages.db')) + self.data = MsgData(self.db_path) + + await self.data.scan_messages(client=self, limit=100) + await self.data.write_sql() + await self.data.load_sql() if not hasattr(self.data, 'reactions'): await self.data.scan_messages(client=self, limit=self.limit, days=self.days) @@ -52,16 +57,6 @@ class Kwaylon(Client): await self.data.scan_messages(client=self, limit=self.limit, days=days) return - # if hasattr(self, 'data'): - # await self.data.add_msg(message) - # - # if (m := self.leaderboard_regex.match(message.content)) is not None: - # try: - # await message.reply(await self.leaderboard(match=m)) - # except KeyError as e: - # LOGGER.exception(e) - # await message.reply(f"I couldn't find any {m.group('emoji')} reactions. Leave me alone!") - # return if (m := self.most_regex.match(message.clean_content)) is not None: await self.data.load_sql() @@ -108,21 +103,21 @@ class Kwaylon(Client): for name, cnt in counts.iteritems()) return res - # async def handle_raw_reaction(self, payload: RawReactionActionEvent): - # LOGGER.info(payload) - # guild = await self.fetch_guild(payload.guild_id) - # channel = await guild.fetch_channel(payload.channel_id) - # message = await channel.fetch_message(payload.message_id) - # - # if payload.event_type == 'REACTION_REMOVE': - # LOGGER.info(f'{payload.emoji} removed from\n{message.author}: {message.content}') - # elif payload.event_type == 'REACTION_ADD': - # LOGGER.info( - # f'{payload.member.display_name} added {payload.emoji} to\n' + \ - # f'{message.author.display_name}: {message.content}') - # - # if hasattr(self, 'data'): - # await self.data.update_reaction(msg=message) + async def handle_raw_reaction(self, payload: RawReactionActionEvent): + LOGGER.info(payload) + guild = await self.fetch_guild(payload.guild_id) + channel = await guild.fetch_channel(payload.channel_id) + message = await channel.fetch_message(payload.message_id) + + if payload.event_type == 'REACTION_REMOVE': + LOGGER.info(f'{payload.emoji} removed from\n{message.author}: {message.content}') + elif payload.event_type == 'REACTION_ADD': + LOGGER.info( + f'{payload.member.display_name} added {payload.emoji} to\n' + \ + f'{message.author.display_name}: {message.content}') + + if hasattr(self, 'data'): + await self.data.update_reaction(msg=message) def get_emoji_name(string: str) -> str: diff --git a/kwaylon/msg.py b/kwaylon/msg.py index ba42ab7..13da027 100644 --- a/kwaylon/msg.py +++ b/kwaylon/msg.py @@ -34,19 +34,23 @@ async def message_gen(client: Client, limit=20, days: int = 90, **kwargs) -> Asy LOGGER.info(f'Done getting messages') +def reaction_dict(reaction: Reaction): + return { + 'msg_id': reaction.message.id, + 'emoji': reaction.emoji.name if reaction.is_custom_emoji() else reaction.emoji, + 'emoji_id': reaction.emoji.id if reaction.is_custom_emoji() else None, + 'channel_id': reaction.message.channel.id, + 'guild_id': reaction.message.channel.guild.id, + 'auth_id': reaction.message.author.id, + 'count': int(reaction.count), + 'datetime': reaction.message.created_at.astimezone(), + } + + async def reaction_gen(client: Client, **kwargs) -> AsyncIterator[Reaction]: async for msg in message_gen(client=client, **kwargs): for reaction in msg.reactions: - yield { - 'msg_id': reaction.message.id, - 'emoji': reaction.emoji.name if reaction.is_custom_emoji() else reaction.emoji, - 'emoji_id': reaction.emoji.id if reaction.is_custom_emoji() else None, - 'channel_id': msg.channel.id, - 'guild_id': msg.channel.guild.id, - 'auth_id': msg.author.id, - 'count': int(reaction.count), - 'datetime': msg.created_at.astimezone(), - } + yield reaction_dict(reaction) async def reaction_df(client: Client, **kwargs): diff --git a/main.py b/main.py index 1ce5ed9..5f1c827 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import os import nextcord as discord from dotenv import load_dotenv +from nextcord import RawReactionActionEvent from kwaylon import Kwaylon @@ -12,7 +13,7 @@ if __name__ == '__main__': client = Kwaylon( # limit=100, - # days=10 + # days=60 ) @@ -26,14 +27,14 @@ if __name__ == '__main__': await client.handle_message(message) - # @client.event - # async def on_raw_reaction_add(payload: RawReactionActionEvent): - # await client.handle_raw_reaction(payload) - # - # - # @client.event - # async def on_raw_reaction_remove(payload: RawReactionActionEvent): - # await client.handle_raw_reaction(payload) + @client.event + async def on_raw_reaction_add(payload: RawReactionActionEvent): + await client.handle_raw_reaction(payload) + + + @client.event + async def on_raw_reaction_remove(payload: RawReactionActionEvent): + await client.handle_raw_reaction(payload) load_dotenv() client.run(os.getenv('DISCORD_TOKEN'))