handling reactions again

This commit is contained in:
jsl12
2022-01-23 13:27:35 -06:00
parent b0515954af
commit fa89302e54
4 changed files with 81 additions and 76 deletions

View File

@@ -8,11 +8,13 @@ import pandas as pd
from nextcord import Client, Message from nextcord import Client, Message
from .msg import LOGGER, reaction_df from .msg import LOGGER, reaction_df
from .msg import reaction_dict
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class MsgData: class MsgData:
"""Wrapper class to manage saving and loading the DataFrame of reactions"""
db_path: Path db_path: Path
msgs: pd.DataFrame msgs: pd.DataFrame
reactions: pd.DataFrame reactions: pd.DataFrame
@@ -22,6 +24,9 @@ class MsgData:
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.db_path: Path = Path(path) if isinstance(path, str) else path self.db_path: Path = Path(path) if isinstance(path, str) else path
def __repr__(self):
return f'<{__name__}.{self.__class__.__name__} with {self.reactions.shape[0]} reactions>'
@property @property
def sql_context(self): def sql_context(self):
return sqlite3.connect(self.db_path) return sqlite3.connect(self.db_path)
@@ -31,16 +36,28 @@ class MsgData:
with self.sql_context as con: with self.sql_context as con:
LOGGER.info(f'Opened {self.db_path.name}') LOGGER.info(f'Opened {self.db_path.name}')
try: try:
self.reactions = pd.read_sql('select * from reactions', con=con).reset_index() self.reactions = pd.read_sql('select * from reactions', con=con).reset_index(drop=True)
self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime']).dt.tz_convert(local_tz)
except: except:
LOGGER.warning(f'failed to read reactions from: {self.db_path.resolve()}') LOGGER.warning(f'failed to read reactions from: {self.db_path.resolve()}')
else: else:
LOGGER.info(f'read {self.reactions.shape[0]:,} reactions') LOGGER.info(f'read {self.reactions.shape[0]:,} reactions')
self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime'])
LOGGER.info(f"'datetime' dtype: {self.reactions['datetime'].dtype}")
LOGGER.info(f"{self.reactions['datetime'].values[:3]}...")
# try:
# self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime']).dt.tz_convert(local_tz)
# except Exception as e:
# LOGGER.exception(e)
# try:
# self.reactions['datetime'] = pd.to_datetime(self.reactions['datetime']).dt.tz_localize(local_tz)
# except Exception as e:
# LOGGER.exception(e)
# LOGGER.warning(f'Error converting timezone to {local_tz}')
con.close() con.close()
async def scan_messages(self, client: Client, **kwargs): async def write_sql(self):
self.reactions = await reaction_df(client, **kwargs)
async with self.lock: async with self.lock:
with self.sql_context as con: with self.sql_context as con:
self.reactions.to_sql( self.reactions.to_sql(
@@ -48,9 +65,14 @@ class MsgData:
con=con, con=con,
if_exists='replace', if_exists='replace',
index=False, index=False,
index_label=self.reactions.index.name # index_label=self.reactions.index.name
) )
LOGGER.info(f'wrote {self.reactions.shape[0]:,} into {self.db_path.name}') LOGGER.info(f'wrote {self.reactions.shape[0]:,} reactions into {self.db_path.name}')
async def scan_messages(self, client: Client, **kwargs):
async with self.lock:
self.reactions = await reaction_df(client, **kwargs)
await self.write_sql()
def most(self, emoji: str): def most(self, emoji: str):
matching = self.reactions['emoji'] == emoji matching = self.reactions['emoji'] == emoji
@@ -65,38 +87,21 @@ class MsgData:
channel = await guild.fetch_channel(row['channel_id']) channel = await guild.fetch_channel(row['channel_id'])
return await channel.fetch_message(row['msg_id']) return await channel.fetch_message(row['msg_id'])
def __repr__(self):
return f'<{__name__}.{self.__class__.__name__} with {self.reactions.shape[0]} reactions>'
# async def add_msg(self, message: Message):
# async with self.lock:
# mdict = message_dict(message)
# mdict.pop('id')
# self.msgs.loc[message.id] = pd.Series(mdict)
# LOGGER.info(f'Added message id {message.id} from {message.author}: {message.content}')
async def update_reaction(self, msg: Message): async def update_reaction(self, msg: Message):
# Drop all the reactions for this message id, if there are any # Drop all the reactions for this message id, if there are any
try: try:
async with self.lock: async with self.lock:
self.reactions.drop(msg.id, level=0, axis=0, inplace=True) self.reactions = self.reactions.loc[self.reactions['msg_id'] != msg.id]
except KeyError as e: except KeyError as e:
pass pass
# If there are reactions on the message after the change # If there are reactions on the message after the change
if len(msg.reactions) > 0: if len(msg.reactions) > 0:
new = reaction_df(msg) new = pd.DataFrame([reaction_dict(r) for r in msg.reactions])
async with self.lock: async with self.lock:
self.reactions = self.reactions.append(new) self.reactions = self.reactions.append(new)
LOGGER.info(str(new.droplevel(level=0, axis=0).loc[:, 'count'])) try:
await self.write_sql()
# if msg.id not in self.msgs.index: except:
# await self.add_msg(msg) LOGGER.info(self.reactions.columns)
LOGGER.info(self.reactions.dtypes)
return new
async def emoji_user_counts(self, client: Client, emoji_name: str, days: int = None):
"""Creates a Series indexed by user display_name with the number of reactions with emoji_name as values"""
counts: pd.Series = self.emoji_totals(emoji_name, days)
counts.index = pd.Index([(await client.fetch_user(user_id=uid)).display_name for uid in counts.index])
return counts

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import nextcord as discord import nextcord as discord
from nextcord import Client, Message, TextChannel from nextcord import Client, Message, TextChannel
from nextcord import RawReactionActionEvent
from . import jokes from . import jokes
from .data import MsgData from .data import MsgData
@@ -15,7 +16,7 @@ LOGGER = logging.getLogger(__name__)
class Kwaylon(Client): class Kwaylon(Client):
db_path: Path = Path('../messages.db') db_path: Path = Path('./messages.db')
def __init__(self, limit: int = 5000, days: int = 30, *args, **kwargs): def __init__(self, limit: int = 5000, days: int = 30, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -36,7 +37,11 @@ class Kwaylon(Client):
# await alive() # await alive()
self.data = MsgData(path=Path('./messages.db')) self.data = MsgData(self.db_path)
await self.data.scan_messages(client=self, limit=100)
await self.data.write_sql()
await self.data.load_sql() await self.data.load_sql()
if not hasattr(self.data, 'reactions'): if not hasattr(self.data, 'reactions'):
await self.data.scan_messages(client=self, limit=self.limit, days=self.days) await self.data.scan_messages(client=self, limit=self.limit, days=self.days)
@@ -52,16 +57,6 @@ class Kwaylon(Client):
await self.data.scan_messages(client=self, limit=self.limit, days=days) await self.data.scan_messages(client=self, limit=self.limit, days=days)
return return
# if hasattr(self, 'data'):
# await self.data.add_msg(message)
#
# if (m := self.leaderboard_regex.match(message.content)) is not None:
# try:
# await message.reply(await self.leaderboard(match=m))
# except KeyError as e:
# LOGGER.exception(e)
# await message.reply(f"I couldn't find any {m.group('emoji')} reactions. Leave me alone!")
# return
if (m := self.most_regex.match(message.clean_content)) is not None: if (m := self.most_regex.match(message.clean_content)) is not None:
await self.data.load_sql() await self.data.load_sql()
@@ -108,21 +103,21 @@ class Kwaylon(Client):
for name, cnt in counts.iteritems()) for name, cnt in counts.iteritems())
return res return res
# async def handle_raw_reaction(self, payload: RawReactionActionEvent): async def handle_raw_reaction(self, payload: RawReactionActionEvent):
# LOGGER.info(payload) LOGGER.info(payload)
# guild = await self.fetch_guild(payload.guild_id) guild = await self.fetch_guild(payload.guild_id)
# channel = await guild.fetch_channel(payload.channel_id) channel = await guild.fetch_channel(payload.channel_id)
# message = await channel.fetch_message(payload.message_id) message = await channel.fetch_message(payload.message_id)
#
# if payload.event_type == 'REACTION_REMOVE': if payload.event_type == 'REACTION_REMOVE':
# LOGGER.info(f'{payload.emoji} removed from\n{message.author}: {message.content}') LOGGER.info(f'{payload.emoji} removed from\n{message.author}: {message.content}')
# elif payload.event_type == 'REACTION_ADD': elif payload.event_type == 'REACTION_ADD':
# LOGGER.info( LOGGER.info(
# f'{payload.member.display_name} added {payload.emoji} to\n' + \ f'{payload.member.display_name} added {payload.emoji} to\n' + \
# f'{message.author.display_name}: {message.content}') f'{message.author.display_name}: {message.content}')
#
# if hasattr(self, 'data'): if hasattr(self, 'data'):
# await self.data.update_reaction(msg=message) await self.data.update_reaction(msg=message)
def get_emoji_name(string: str) -> str: def get_emoji_name(string: str) -> str:

View File

@@ -34,19 +34,23 @@ async def message_gen(client: Client, limit=20, days: int = 90, **kwargs) -> Asy
LOGGER.info(f'Done getting messages') LOGGER.info(f'Done getting messages')
def reaction_dict(reaction: Reaction):
return {
'msg_id': reaction.message.id,
'emoji': reaction.emoji.name if reaction.is_custom_emoji() else reaction.emoji,
'emoji_id': reaction.emoji.id if reaction.is_custom_emoji() else None,
'channel_id': reaction.message.channel.id,
'guild_id': reaction.message.channel.guild.id,
'auth_id': reaction.message.author.id,
'count': int(reaction.count),
'datetime': reaction.message.created_at.astimezone(),
}
async def reaction_gen(client: Client, **kwargs) -> AsyncIterator[Reaction]: async def reaction_gen(client: Client, **kwargs) -> AsyncIterator[Reaction]:
async for msg in message_gen(client=client, **kwargs): async for msg in message_gen(client=client, **kwargs):
for reaction in msg.reactions: for reaction in msg.reactions:
yield { yield reaction_dict(reaction)
'msg_id': reaction.message.id,
'emoji': reaction.emoji.name if reaction.is_custom_emoji() else reaction.emoji,
'emoji_id': reaction.emoji.id if reaction.is_custom_emoji() else None,
'channel_id': msg.channel.id,
'guild_id': msg.channel.guild.id,
'auth_id': msg.author.id,
'count': int(reaction.count),
'datetime': msg.created_at.astimezone(),
}
async def reaction_df(client: Client, **kwargs): async def reaction_df(client: Client, **kwargs):

19
main.py
View File

@@ -2,6 +2,7 @@ import os
import nextcord as discord import nextcord as discord
from dotenv import load_dotenv from dotenv import load_dotenv
from nextcord import RawReactionActionEvent
from kwaylon import Kwaylon from kwaylon import Kwaylon
@@ -12,7 +13,7 @@ if __name__ == '__main__':
client = Kwaylon( client = Kwaylon(
# limit=100, # limit=100,
# days=10 # days=60
) )
@@ -26,14 +27,14 @@ if __name__ == '__main__':
await client.handle_message(message) await client.handle_message(message)
# @client.event @client.event
# async def on_raw_reaction_add(payload: RawReactionActionEvent): async def on_raw_reaction_add(payload: RawReactionActionEvent):
# await client.handle_raw_reaction(payload) await client.handle_raw_reaction(payload)
#
#
# @client.event @client.event
# async def on_raw_reaction_remove(payload: RawReactionActionEvent): async def on_raw_reaction_remove(payload: RawReactionActionEvent):
# await client.handle_raw_reaction(payload) await client.handle_raw_reaction(payload)
load_dotenv() load_dotenv()
client.run(os.getenv('DISCORD_TOKEN')) client.run(os.getenv('DISCORD_TOKEN'))