Files
kwaylon/kwaylon/data.py
2022-01-22 02:07:06 -06:00

103 lines
3.8 KiB
Python

import asyncio
import logging
import sqlite3
from contextlib import closing, suppress
from pathlib import Path
from typing import Union

import pandas as pd
from nextcord import Client, Message

from .msg import LOGGER, reaction_df
# Module logger. This rebinds the LOGGER name imported from .msg above, so the
# code in this module logs under this module's own name (__name__).
LOGGER = logging.getLogger(__name__)
class MsgData:
    """Reaction data for Discord messages, cached in memory and persisted to SQLite.

    ``reactions`` is populated by :meth:`load_sql` (from the database file) or by
    :meth:`scan_messages` (via ``reaction_df``); it does not exist until one of them
    succeeds. ``lock`` serialises access to ``reactions`` and to the database file.
    """

    db_path: Path
    msgs: pd.DataFrame
    reactions: pd.DataFrame
    lock: asyncio.Lock

    def __init__(self, path: Union[str, Path]):
        """Remember where the SQLite database lives; nothing is opened yet.

        Args:
            path: Location of the SQLite file, as a ``str`` or ``Path``.
        """
        self.lock = asyncio.Lock()
        # Path() accepts both str and Path, so no isinstance check is needed.
        self.db_path: Path = Path(path)

    @property
    def sql_context(self):
        """Return a new ``sqlite3.Connection`` to :attr:`db_path`.

        ``with conn:`` on a sqlite3 connection only commits / rolls back; it does
        NOT close the connection. The methods in this class therefore wrap it in
        ``contextlib.closing`` so the connection is always closed.
        """
        return sqlite3.connect(self.db_path)

    async def load_sql(self, local_tz='US/Central'):
        """Load the ``reactions`` table from the database into :attr:`reactions`.

        Args:
            local_tz: Time zone the ``datetime`` column is converted to.

        Failures are logged (with traceback) rather than raised. On failure
        :attr:`reactions` keeps its previous value (or stays unset).
        """
        async with self.lock:
            # closing() guarantees the connection is released even if reading fails.
            with closing(self.sql_context) as con:
                LOGGER.info(f'Opened {self.db_path.name}')
                try:
                    reactions = pd.read_sql('select * from reactions', con=con).reset_index()
                    # utc=True parses timestamps carrying different UTC offsets (e.g. on
                    # both sides of a DST change) into one tz-aware column; without it
                    # mixed offsets give an object column and .dt fails. Naive values
                    # are interpreted as UTC.
                    reactions['datetime'] = pd.to_datetime(
                        reactions['datetime'], utc=True
                    ).dt.tz_convert(local_tz)
                except Exception:
                    LOGGER.warning(
                        f'failed to read reactions from: {self.db_path.resolve()}',
                        exc_info=True,
                    )
                else:
                    # Only publish the frame once it is fully converted.
                    self.reactions = reactions
                    LOGGER.info(f'read {self.reactions.shape[0]:,} reactions')

    async def scan_messages(self, client: Client, **kwargs):
        """Rebuild :attr:`reactions` via ``reaction_df`` and overwrite the database table.

        Args:
            client: Discord client forwarded to ``reaction_df``.
            **kwargs: Extra keyword arguments forwarded to ``reaction_df``.
        """
        reactions = await reaction_df(client, **kwargs)
        async with self.lock:
            # Swap the new frame in under the lock so concurrent updates don't race it.
            self.reactions = reactions
            with closing(self.sql_context) as con:
                with con:  # commit on success, roll back on error
                    # NOTE(review): index=False discards the frame's index and load_sql
                    # reloads with a fresh RangeIndex; confirm update_reaction's level-0
                    # drop still works on data that came from the database.
                    reactions.to_sql(
                        name='reactions',
                        con=con,
                        if_exists='replace',
                        index=False,
                    )
        LOGGER.info(f'wrote {reactions.shape[0]:,} into {self.db_path.name}')

    def most(self, emoji: str):
        """Return the reactions that use ``emoji``, sorted by ``count`` descending.

        Returns:
            A new DataFrame with a fresh 0..n-1 index, or ``None`` if no reaction
            uses ``emoji``.
        """
        matching = self.reactions['emoji'] == emoji
        if not matching.any():
            LOGGER.info(f'No reactions with {emoji}')
            return None
        return (
            self.reactions.loc[matching]
            .sort_values('count', ascending=False)
            .reset_index(drop=True)
        )

    async def fetch_message(self, client: Client, row: pd.Series):
        """Fetch the Discord message a reactions row refers to.

        Args:
            client: Discord client used for the guild/channel/message lookups.
            row: A row providing ``guild_id``, ``channel_id`` and ``msg_id``.
        """
        guild = await client.fetch_guild(row['guild_id'])
        channel = await guild.fetch_channel(row['channel_id'])
        return await channel.fetch_message(row['msg_id'])

    def __repr__(self):
        # reactions only exists after a successful load/scan; repr must not raise.
        reactions = getattr(self, 'reactions', None)
        if reactions is None:
            return f'<{__name__}.{self.__class__.__name__} with no reactions loaded>'
        return f'<{__name__}.{self.__class__.__name__} with {reactions.shape[0]} reactions>'

    # Disabled: message tracking (self.msgs) is currently not maintained.
    # async def add_msg(self, message: Message):
    #     async with self.lock:
    #         mdict = message_dict(message)
    #         mdict.pop('id')
    #         self.msgs.loc[message.id] = pd.Series(mdict)
    #         LOGGER.info(f'Added message id {message.id} from {message.author}: {message.content}')

    async def update_reaction(self, msg: Message):
        """Replace the cached reactions for ``msg`` with its current reactions.

        Args:
            msg: The message whose reactions changed.

        Returns:
            The DataFrame of ``msg``'s current reactions, or ``None`` if it has none.
        """
        # Build the replacement rows before taking the lock. reaction_df is awaited
        # in scan_messages, so it must be awaited here too.
        # NOTE(review): scan_messages passes a Client; confirm reaction_df also
        # accepts a single Message as done here.
        new = await reaction_df(msg) if msg.reactions else None
        async with self.lock:
            # Drop and re-add under one lock acquisition so nobody observes the
            # message with its old rows removed but the new ones not yet added.
            # KeyError just means nothing was cached for this message.
            with suppress(KeyError):
                self.reactions.drop(msg.id, level=0, axis=0, inplace=True)
            if new is not None:
                # DataFrame.append was removed in pandas 2.0; concat replaces it.
                self.reactions = pd.concat([self.reactions, new])
        if new is None:
            return None
        LOGGER.info(str(new.droplevel(level=0, axis=0).loc[:, 'count']))
        # if msg.id not in self.msgs.index:
        #     await self.add_msg(msg)
        return new

    async def emoji_user_counts(self, client: Client, emoji_name: str, days: Union[int, None] = None):
        """Creates a Series indexed by user display_name with the number of reactions with emoji_name as values"""
        # NOTE(review): emoji_totals is not defined on MsgData in this file, so as
        # written this raises AttributeError; confirm where it should come from.
        counts: pd.Series = self.emoji_totals(emoji_name, days)
        # One fetch per user id, performed sequentially; the id is passed
        # positionally, which works whether or not the client's parameter is
        # positional-only.
        counts.index = pd.Index(
            [(await client.fetch_user(uid)).display_name for uid in counts.index]
        )
        return counts