2024-06-05 16:37:23 +02:00

94 lines
3.1 KiB
Python

from collections.abc import Iterable, Iterator
from typing import cast
import networkx as nx
import numpy as np
import numpy.typing as npt
import sentence_transformers
import sentence_transformers.util
from networkx import Graph
from pandas import Series
from sentence_transformers import SentenceTransformer
from torch import Tensor
from lang_main.analysis.graphs import get_graph_metadata, update_graph
from lang_main.types import PandasIndex
def candidates_by_index(
    data_model_input: Series,
    model: SentenceTransformer,
    cos_sim_threshold: float = 0.5,
) -> Iterator[tuple[PandasIndex, PandasIndex]]:
    """Yield index pairs of entries whose embeddings are sufficiently similar.

    All text entries are encoded in a single batch with the given
    SentenceTransformer model and compared pairwise by cosine similarity.
    The data is fed as a Series to retain the index labels of the entries,
    so they can be used to access the entries later in the original dataset.

    Only the strict upper triangle of the similarity matrix is evaluated:
    an entry is never paired with itself and every unordered pair is
    reported at most once, independent of the chosen threshold.

    Parameters
    ----------
    data_model_input : Series
        containing indices and text entries to process
    model : SentenceTransformer
        necessary SentenceTransformer model to encode text entries
    cos_sim_threshold : float, optional
        threshold for cosine similarity to filter candidates, by default 0.5

    Yields
    ------
    tuple[PandasIndex, PandasIndex]
        pair of index labels whose cosine similarity is greater than or
        equal to the threshold
    """
    # fewer than two entries can never form a pair, so skip the encoding step
    if len(data_model_input) < 2:
        return

    # embeddings
    batch = cast(list[str], data_model_input.to_list())
    embds = cast(
        Tensor,
        model.encode(
            batch,
            convert_to_numpy=False,
            convert_to_tensor=True,
            show_progress_bar=False,
        ),
    )

    # cosine similarity
    # the tensor lives on the model's device; it has to be moved to host
    # memory first because ``Tensor.numpy()`` raises for non-CPU tensors
    cos_sim = cast(
        npt.NDArray,
        sentence_transformers.util.cos_sim(embds, embds).cpu().numpy(),
    )

    # select the strict upper triangle explicitly instead of zeroing the
    # remaining cells: zeroed cells would satisfy ``>= threshold`` for any
    # threshold <= 0 and produce self-pairs and mirrored duplicates
    rows, cols = np.triu_indices(cos_sim.shape[0], k=1)
    hits = cos_sim[rows, cols] >= cos_sim_threshold

    # map matrix positions back to the index labels of the input Series
    index = data_model_input.index
    for row, col in zip(rows[hits], cols[hits]):
        yield cast(tuple[PandasIndex, PandasIndex], (index[row], index[col]))
def similar_index_connection_graph(
    similar_idx_pairs: Iterable[tuple[PandasIndex, PandasIndex]],
) -> tuple[Graph, dict[str, int]]:
    """Build an undirected graph from pairs of similar indices.

    The resulting graph can be used to obtain connected components, i.e.
    groups of indices which belong together, and thereby retains the
    semantic connection across the whole dataset.

    Parameters
    ----------
    similar_idx_pairs : Iterable[tuple[PandasIndex, PandasIndex]]
        index pairs which were identified as similar

    Returns
    -------
    tuple[Graph, dict[str, int]]
        the index graph and the metadata obtained for it from
        ``get_graph_metadata``
    """
    index_graph = Graph()
    # the graph is filled in place with the whole batch of pairs
    # (presumably one edge per pair -- see ``update_graph``); parent/child
    # roles do not really exist in an undirected graph
    update_graph(graph=index_graph, batch=similar_idx_pairs)
    metadata = get_graph_metadata(graph=index_graph, logging=False)

    return index_graph, metadata
def similar_index_groups(
    similar_id_graph: Graph,
) -> Iterator[tuple[PandasIndex, ...]]:
    """Yield the groups of connected indices of the given graph.

    Parameters
    ----------
    similar_id_graph : Graph
        graph whose nodes are indices and whose connected components
        represent indices which belong together

    Yields
    ------
    tuple[PandasIndex, ...]
        one tuple per connected component containing all of its indices
    """
    components = cast(
        Iterator[set[PandasIndex]],
        nx.connected_components(G=similar_id_graph),
    )
    # each component is a set; callers receive it as a tuple
    yield from (tuple(component) for component in components)