from collections.abc import Iterable, Iterator
from typing import cast

import networkx as nx
import numpy as np
import numpy.typing as npt
import sentence_transformers
import sentence_transformers.util
from networkx import Graph
from pandas import Series
from sentence_transformers import SentenceTransformer
from torch import Tensor

from lang_main.analysis.graphs import get_graph_metadata, update_graph
from lang_main.types import PandasIndex


def candidates_by_index(
    data_model_input: Series,
    model: SentenceTransformer,
    cos_sim_threshold: float = 0.5,
) -> Iterator[tuple[PandasIndex, PandasIndex]]:
    """Yield index pairs of entries whose cosine similarity meets a threshold.

    The entries are encoded in one batch with the SentenceTransformer model.
    Data is fed as a Series to retain information about the indices of the
    entries, so they can be accessed later in the original dataset.

    Each unordered pair of distinct entries is evaluated exactly once (strict
    upper triangle of the similarity matrix), so neither self-pairs nor
    mirrored duplicates are yielded, regardless of the threshold value.

    Parameters
    ----------
    data_model_input : Series
        containing indices and text entries to process
    model : SentenceTransformer
        necessary SentenceTransformer model to encode text entries
    cos_sim_threshold : float, optional
        threshold for cosine similarity to filter candidates, by default 0.5

    Yields
    ------
    Iterator[tuple[PandasIndex, PandasIndex]]
        tuple of index pairs which meet the cosine similarity threshold,
        in row-major order of their positions in ``data_model_input``
    """
    # fewer than two entries cannot form a pair; skip the encoding entirely
    if len(data_model_input) < 2:
        return

    # embeddings
    batch = cast(list[str], data_model_input.to_list())
    embds = cast(
        Tensor,
        model.encode(
            batch,
            convert_to_numpy=False,
            convert_to_tensor=True,
            show_progress_bar=False,
        ),
    )

    # cosine similarity
    # ``Tensor.numpy()`` only works for CPU tensors without gradient tracking,
    # so move the result off any accelerator device before converting
    cos_sim = cast(
        npt.NDArray,
        sentence_transformers.util.cos_sim(embds, embds).detach().cpu().numpy(),
    )

    # only look at the strict upper triangle (k=1 excludes the diagonal);
    # zeroing the other cells instead would let them pass for thresholds <= 0
    rows, cols = np.triu_indices(cos_sim.shape[0], k=1)
    meets_threshold = cos_sim[rows, cols] >= cos_sim_threshold

    # translate matrix positions back into the index labels of the Series
    index = data_model_input.index
    for row, col in zip(rows[meets_threshold], cols[meets_threshold]):
        yield cast(tuple[PandasIndex, PandasIndex], (index[row], index[col]))


def similar_index_connection_graph(
    similar_idx_pairs: Iterable[tuple[PandasIndex, PandasIndex]],
) -> tuple[Graph, dict[str, int]]:
    """Build an undirected graph from pairs of similar indices.

    The graph is used to obtain connected components, i.e. indices which
    belong together, retaining the semantic connection on the whole dataset.

    Parameters
    ----------
    similar_idx_pairs : Iterable[tuple[PandasIndex, PandasIndex]]
        index pairs which were identified as similar

    Returns
    -------
    tuple[Graph, dict[str, int]]
        the graph filled via ``update_graph`` and the metadata returned by
        ``get_graph_metadata`` for that graph
    """
    similar_id_graph = nx.Graph()
    # inplace operation on the graph, all pairs are passed as one batch
    update_graph(graph=similar_id_graph, batch=similar_idx_pairs)
    graph_info = get_graph_metadata(graph=similar_id_graph, logging=False)

    return similar_id_graph, graph_info


def similar_index_groups(
    similar_id_graph: Graph,
) -> Iterator[tuple[PandasIndex, ...]]:
    """Yield the groups of connected indices of a similarity graph.

    Parameters
    ----------
    similar_id_graph : Graph
        graph whose nodes are indices and whose edges connect similar entries

    Yields
    ------
    Iterator[tuple[PandasIndex, ...]]
        one tuple per connected component containing its indices; the
        components are provided as sets, so the order within a tuple is
        not defined
    """
    # groups of connected indices
    components = cast(
        Iterator[set[PandasIndex]],
        nx.connected_components(G=similar_id_graph),
    )
    for component in components:
        yield tuple(component)