|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import logging |
|
|
from collections.abc import Iterable |
|
|
from scipy.sparse import csr_matrix |
|
|
from scipy.spatial.distance import squareform |
|
|
|
|
|
|
|
|
class MyLogger: |
|
|
def __init__(self, level): |
|
|
self.logger = logging.getLogger('BERTopic') |
|
|
self.set_level(level) |
|
|
self._add_handler() |
|
|
self.logger.propagate = False |
|
|
|
|
|
def info(self, message): |
|
|
self.logger.info(f"{message}") |
|
|
|
|
|
def warning(self, message): |
|
|
self.logger.warning(f"WARNING: {message}") |
|
|
|
|
|
def set_level(self, level): |
|
|
levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] |
|
|
if level in levels: |
|
|
self.logger.setLevel(level) |
|
|
|
|
|
def _add_handler(self): |
|
|
sh = logging.StreamHandler() |
|
|
sh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s')) |
|
|
self.logger.addHandler(sh) |
|
|
|
|
|
|
|
|
if len(self.logger.handlers) > 1: |
|
|
self.logger.handlers = [self.logger.handlers[0]] |
|
|
|
|
|
|
|
|
def check_documents_type(documents): |
|
|
""" Check whether the input documents are indeed a list of strings """ |
|
|
if isinstance(documents, pd.DataFrame): |
|
|
raise TypeError("Make sure to supply a list of strings, not a dataframe.") |
|
|
elif isinstance(documents, Iterable) and not isinstance(documents, str): |
|
|
if not any([isinstance(doc, str) for doc in documents]): |
|
|
raise TypeError("Make sure that the iterable only contains strings.") |
|
|
else: |
|
|
raise TypeError("Make sure that the documents variable is an iterable containing strings only.") |
|
|
|
|
|
|
|
|
def check_embeddings_shape(embeddings, docs): |
|
|
""" Check if the embeddings have the correct shape """ |
|
|
if embeddings is not None: |
|
|
if not any([isinstance(embeddings, np.ndarray), isinstance(embeddings, csr_matrix)]): |
|
|
raise ValueError("Make sure to input embeddings as a numpy array or scipy.sparse.csr.csr_matrix. ") |
|
|
else: |
|
|
if embeddings.shape[0] != len(docs): |
|
|
raise ValueError("Make sure that the embeddings are a numpy array with shape: " |
|
|
"(len(docs), vector_dim) where vector_dim is the dimensionality " |
|
|
"of the vector embeddings. ") |
|
|
|
|
|
|
|
|
def check_is_fitted(topic_model): |
|
|
""" Checks if the model was fitted by verifying the presence of self.matches |
|
|
|
|
|
Arguments: |
|
|
model: BERTopic instance for which the check is performed. |
|
|
|
|
|
Returns: |
|
|
None |
|
|
|
|
|
Raises: |
|
|
ValueError: If the matches were not found. |
|
|
""" |
|
|
msg = ("This %(name)s instance is not fitted yet. Call 'fit' with " |
|
|
"appropriate arguments before using this estimator.") |
|
|
|
|
|
if topic_model.topics_ is None: |
|
|
raise ValueError(msg % {'name': type(topic_model).__name__}) |
|
|
|
|
|
|
|
|
class NotInstalled: |
|
|
""" |
|
|
This object is used to notify the user that additional dependencies need to be |
|
|
installed in order to use the string matching model. |
|
|
""" |
|
|
|
|
|
def __init__(self, tool, dep, custom_msg=None): |
|
|
self.tool = tool |
|
|
self.dep = dep |
|
|
|
|
|
msg = f"In order to use {self.tool} you will need to install via;\n\n" |
|
|
if custom_msg is not None: |
|
|
msg += custom_msg |
|
|
else: |
|
|
msg += f"pip install bertopic[{self.dep}]\n\n" |
|
|
self.msg = msg |
|
|
|
|
|
def __getattr__(self, *args, **kwargs): |
|
|
raise ModuleNotFoundError(self.msg) |
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
raise ModuleNotFoundError(self.msg) |
|
|
|
|
|
|
|
|
def validate_distance_matrix(X, n_samples): |
|
|
""" Validate the distance matrix and convert it to a condensed distance matrix |
|
|
if necessary. |
|
|
|
|
|
A valid distance matrix is either a square matrix of shape (n_samples, n_samples) |
|
|
with zeros on the diagonal and non-negative values or condensed distance matrix |
|
|
of shape (n_samples * (n_samples - 1) / 2,) containing the upper triangular of the |
|
|
distance matrix. |
|
|
|
|
|
Arguments: |
|
|
X: Distance matrix to validate. |
|
|
n_samples: Number of samples in the dataset. |
|
|
|
|
|
Returns: |
|
|
X: Validated distance matrix. |
|
|
|
|
|
Raises: |
|
|
ValueError: If the distance matrix is not valid. |
|
|
""" |
|
|
|
|
|
s = X.shape |
|
|
if len(s) == 1: |
|
|
|
|
|
n = s[0] |
|
|
if n != (n_samples * (n_samples - 1) / 2): |
|
|
raise ValueError("The condensed distance matrix must have " |
|
|
"shape (n*(n-1)/2,).") |
|
|
elif len(s) == 2: |
|
|
|
|
|
if (s[0] != n_samples) or (s[1] != n_samples): |
|
|
raise ValueError("The distance matrix must be of shape " |
|
|
"(n, n) where n is the number of samples.") |
|
|
|
|
|
np.fill_diagonal(X, 0) |
|
|
X = squareform(X) |
|
|
else: |
|
|
raise ValueError("The distance matrix must be either a 1-D condensed " |
|
|
"distance matrix of shape (n*(n-1)/2,) or a " |
|
|
"2-D square distance matrix of shape (n, n)." |
|
|
"where n is the number of documents." |
|
|
"Got a distance matrix of shape %s" % str(s)) |
|
|
|
|
|
|
|
|
if np.any(X < 0): |
|
|
raise ValueError("Distance matrix cannot contain negative values.") |
|
|
|
|
|
return X |
|
|
|