File size: 5,399 Bytes
19b102a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import numpy as np
import pandas as pd
import logging
from collections.abc import Iterable
from scipy.sparse import csr_matrix
from scipy.spatial.distance import squareform
class MyLogger:
    """ Small convenience wrapper around the shared 'BERTopic' logger

    Every instance wraps the same `logging.getLogger('BERTopic')` object, so
    constructing another instance re-configures that one logger. Propagation
    to ancestor loggers is switched off and only a single handler is kept.

    Arguments:
        level: Name of the logging level to use, one of "DEBUG", "INFO",
               "WARNING", "ERROR" or "CRITICAL". Any other value is ignored.
    """

    def __init__(self, level):
        self.logger = logging.getLogger('BERTopic')
        self.set_level(level)
        self._add_handler()
        self.logger.propagate = False

    def info(self, message):
        """ Emit `message` at INFO level """
        self.logger.info(f"{message}")

    def warning(self, message):
        """ Emit `message` at WARNING level, prefixed with 'WARNING: ' """
        self.logger.warning(f"WARNING: {message}")

    def set_level(self, level):
        """ Set the logger level if `level` is a recognised level name

        Unrecognised values leave the current level untouched; no error
        is raised.
        """
        if level not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
            return
        self.logger.setLevel(level)

    def _add_handler(self):
        """ Attach a formatted stream handler and drop any surplus handlers """
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(message)s')
        )
        self.logger.addHandler(stream_handler)

        # The logger is shared between instances: keep only the handler that
        # was registered first so messages are not emitted more than once.
        first_handler, *surplus = self.logger.handlers
        if surplus:
            self.logger.handlers = [first_handler]
def check_documents_type(documents):
    """ Check whether the input documents are indeed an iterable of strings

    Note that the elements are iterated in order to be checked, so a one-shot
    iterator (e.g. a generator) is (partially) consumed by this check.

    Arguments:
        documents: The documents to validate. Must be a non-string,
                   non-dataframe iterable in which every element is a `str`.

    Returns:
        None

    Raises:
        TypeError: If `documents` is a dataframe, a single string, not
                   iterable at all, empty, or contains a non-string element.
    """
    if isinstance(documents, pd.DataFrame):
        raise TypeError("Make sure to supply a list of strings, not a dataframe.")

    # A bare string is iterable, but it is a single document, not a collection
    if isinstance(documents, str) or not isinstance(documents, Iterable):
        raise TypeError("Make sure that the documents variable is an iterable containing strings only.")

    # Every element has to be a string; a single valid string among
    # non-strings is not enough.
    found_document = False
    for doc in documents:
        if not isinstance(doc, str):
            raise TypeError("Make sure that the iterable only contains strings.")
        found_document = True

    # An empty iterable has always been rejected here; keep doing so
    if not found_document:
        raise TypeError("Make sure that the iterable only contains strings.")
def check_embeddings_shape(embeddings, docs):
    """ Check if the embeddings have the correct type and number of rows

    Arguments:
        embeddings: Either None (nothing is checked), a numpy array or a
                    scipy csr_matrix holding one row per document.
        docs: The documents the embeddings belong to; only `len(docs)` is used.

    Returns:
        None

    Raises:
        ValueError: If the embeddings are of an unsupported type or if their
                    first dimension differs from the number of documents.
    """
    # Embeddings are optional
    if embeddings is None:
        return

    if not isinstance(embeddings, (np.ndarray, csr_matrix)):
        raise ValueError("Make sure to input embeddings as a numpy array or scipy.sparse.csr.csr_matrix. ")

    n_rows = embeddings.shape[0]
    if n_rows != len(docs):
        raise ValueError("Make sure that the embeddings are a numpy array with shape: "
                         "(len(docs), vector_dim) where vector_dim is the dimensionality "
                         "of the vector embeddings. ")
def check_is_fitted(topic_model):
    """ Check that the model was fitted by verifying that `topics_` is set

    Arguments:
        topic_model: Model instance for which the check is performed. It is
                     treated as not fitted when its `topics_` attribute is
                     missing or None.

    Returns:
        None

    Raises:
        ValueError: If `topics_` is missing or None.
    """
    # Use getattr so that an object without a `topics_` attribute yields the
    # documented ValueError rather than an unrelated AttributeError.
    if getattr(topic_model, "topics_", None) is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this estimator.")
        raise ValueError(msg % {'name': type(topic_model).__name__})
class NotInstalled:
    """
    Placeholder object that tells the user which additional dependencies need
    to be installed before a given tool can be used.

    The installation hint is built once at construction time. Calling the
    object, or looking up any attribute that is not found the normal way,
    raises a ModuleNotFoundError carrying that hint.

    Arguments:
        tool: Name of the tool that is unavailable; used in the message.
        dep: Name of the pip extra, rendered as `bertopic[dep]`.
        custom_msg: Optional text that replaces the default
                    `pip install bertopic[dep]` instruction.
    """

    def __init__(self, tool, dep, custom_msg=None):
        self.tool = tool
        self.dep = dep

        # Either the caller-supplied instruction or the default pip command
        instruction = (
            custom_msg
            if custom_msg is not None
            else f"pip install bertopic[{self.dep}]\n\n"
        )
        self.msg = f"In order to use {self.tool} you will need to install via;\n\n" + instruction

    def __getattr__(self, *args, **kwargs):
        # Only reached for attributes that normal lookup could not find
        raise ModuleNotFoundError(self.msg)

    def __call__(self, *args, **kwargs):
        raise ModuleNotFoundError(self.msg)
def validate_distance_matrix(X, n_samples):
    """ Validate the distance matrix and convert it to a condensed distance matrix
    if necessary.

    A valid distance matrix is either a square matrix of shape (n_samples, n_samples)
    with zeros on the diagonal and non-negative values or condensed distance matrix
    of shape (n_samples * (n_samples - 1) / 2,) containing the upper triangular of the
    distance matrix.

    A square input is never modified: its diagonal is zeroed on a copy, which
    is then converted to condensed form. A condensed input is returned as-is.

    Arguments:
        X: Distance matrix to validate.
        n_samples: Number of samples in the dataset.

    Returns:
        X: Validated distance matrix in condensed (1-D) form.

    Raises:
        ValueError: If the distance matrix is not valid.
    """
    s = X.shape
    if len(s) == 1:
        # Condensed form: must hold exactly one entry per unordered pair.
        # n * (n - 1) is always even, so integer division is exact.
        if s[0] != n_samples * (n_samples - 1) // 2:
            raise ValueError("The condensed distance matrix must have "
                             "shape (n*(n-1)/2,).")
    elif len(s) == 2:
        # Square form: must be n_samples x n_samples
        if (s[0] != n_samples) or (s[1] != n_samples):
            raise ValueError("The distance matrix must be of shape "
                             "(n, n) where n is the number of samples.")
        # Force a zero diagonal on a copy so the caller's matrix is left
        # untouched, then convert to condensed form.
        square = np.array(X, copy=True)
        np.fill_diagonal(square, 0)
        X = squareform(square)
    else:
        raise ValueError("The distance matrix must be either a 1-D condensed "
                         "distance matrix of shape (n*(n-1)/2,) or a "
                         "2-D square distance matrix of shape (n, n), "
                         "where n is the number of documents. "
                         "Got a distance matrix of shape %s" % str(s))

    # Make sure its entries are non-negative
    if np.any(X < 0):
        raise ValueError("Distance matrix cannot contain negative values.")

    return X
|