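"""Find and group semantically similar dataset entries.

The helpers in this module encode text entries with a SentenceTransformer
model, filter index pairs by cosine similarity, connect the resulting index
pairs in an undirected graph, and yield its connected components as groups
of indices that belong together.
"""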
from collections.abc import Iterable, Iterator
from typing import cast

import networkx as nx
import numpy as np
import numpy.typing as npt
import sentence_transformers
import sentence_transformers.util
from networkx import Graph
from pandas import Series
from sentence_transformers import SentenceTransformer
from torch import Tensor

from lang_main.analysis.graphs import get_graph_metadata, update_graph
from lang_main.types import PandasIndex


def candidates_by_index(
    data_model_input: Series,
    model: SentenceTransformer,
    cos_sim_threshold: float = 0.5,
) -> Iterator[tuple[PandasIndex, PandasIndex]]:
"""function to filter candidate indices based on cosine similarity
|
|
using SentenceTransformer model in batch mode,
|
|
feed data as Series to retain information about indices of entries and
|
|
access them later in the original dataset
|
|
|
|
Parameters
|
|
----------
|
|
obj_id : ObjectID
|
|
_description_
|
|
data_model_input : Series
|
|
containing indices and text entries to process
|
|
model : SentenceTransformer
|
|
necessary SentenceTransformer model to encode text entries
|
|
cos_sim_threshold : float, optional
|
|
threshold for cosine similarity to filter candidates, by default 0.5
|
|
|
|
Yields
|
|
------
|
|
Iterator[tuple[PandasIndex, PandasIndex]]
|
|
tuple of index pairs which meet the cosine similarity threshold
|
|
"""
|
|
    # embeddings
    batch = cast(list[str], data_model_input.to_list())
    embds = cast(
        Tensor,
        model.encode(
            batch,
            convert_to_numpy=False,
            convert_to_tensor=True,
            show_progress_bar=False,
        ),
    )
    # cosine similarity; zero out the diagonal (self-similarity) and keep only
    # the upper triangle to avoid duplicate (i, j)/(j, i) pairs
    cos_sim = cast(npt.NDArray, sentence_transformers.util.cos_sim(embds, embds).numpy())
    np.fill_diagonal(cos_sim, 0.0)
    cos_sim = np.triu(cos_sim)
    cos_sim_idx = np.argwhere(cos_sim >= cos_sim_threshold)

    for idx_array in cos_sim_idx:
        idx_pair = cast(
            tuple[np.int64, np.int64],
            tuple(data_model_input.index[idx] for idx in idx_array),
        )
        yield idx_pair


def similar_index_connection_graph(
    similar_idx_pairs: Iterable[tuple[PandasIndex, PandasIndex]],
) -> tuple[Graph, dict[str, int]]:
    # build an index graph to obtain a graph of connected (similar) indices;
    # its connected components are the indices which belong together,
    # retaining the semantic connections across the whole dataset
    similar_id_graph = nx.Graph()
    # update_graph adds the pairs to the graph in place; parent/child roles do
    # not really exist in an undirected graph
    update_graph(graph=similar_id_graph, batch=similar_idx_pairs)

    graph_info = get_graph_metadata(graph=similar_id_graph, logging=False)

    return similar_id_graph, graph_info


def similar_index_groups(
    similar_id_graph: Graph,
) -> Iterator[tuple[PandasIndex, ...]]:
    # groups of connected indices
    ids_groups = cast(Iterator[set[PandasIndex]], nx.connected_components(G=similar_id_graph))

    for id_group in ids_groups:
        yield tuple(id_group)
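

# Minimal usage sketch showing how the three helpers above chain together.
# The example texts and the "all-MiniLM-L6-v2" model name are illustrative
# assumptions, not project configuration.
if __name__ == "__main__":
    example_texts = Series(
        [
            "pump failure at station 3",
            "pump broke down at station 3",
            "routine inspection without findings",
        ]
    )
    st_model = SentenceTransformer("all-MiniLM-L6-v2")

    # candidate index pairs above the similarity threshold
    pairs = candidates_by_index(example_texts, st_model, cos_sim_threshold=0.5)
    # connection graph of similar indices plus its metadata
    graph, graph_info = similar_index_connection_graph(pairs)
    # groups of indices that belong together (connected components)
    for group in similar_index_groups(graph):
        print(group)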