From fb28b8548be00ff1b0e5802fec09b5494ff7e656 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20F=C3=B6rster?=
Date: Wed, 22 Jan 2025 16:54:15 +0100
Subject: [PATCH] added test cases

---
 docs/lang_main/analysis/graphs.html        |  1844 +++
 docs/lang_main/analysis/index.html         |    98 +
 docs/lang_main/analysis/preprocessing.html |   451 +
 docs/lang_main/analysis/shared.html        |   273 +
 docs/lang_main/analysis/timeline.html      |   333 +
 docs/lang_main/analysis/tokens.html        |   320 +
 docs/lang_main/config.html                 |   206 +
 docs/lang_main/constants.html              |    66 +
 docs/lang_main/errors.html                 |   330 +
 docs/lang_main/index.html                  |   123 +
 docs/lang_main/io.html                     |   227 +
 docs/lang_main/loggers.html                |    66 +
 docs/lang_main/model_loader.html           |   162 +
 docs/lang_main/pipelines/base.html         |   755 ++
 docs/lang_main/pipelines/index.html        |    83 +
 docs/lang_main/pipelines/predefined.html   |   386 +
 docs/lang_main/render/cytoscape.html       |   797 ++
 .../render/cytoscape_monkeypatch.html      |   182 +
 docs/lang_main/render/index.html           |    83 +
 docs/lang_main/search.html                 |   261 +
 docs/lang_main/types.html                  | 10637 ++++++++++++++++
 src/lang_main/analysis/shared.py           |     3 -
 src/lang_main/analysis/timeline.py         |     2 +-
 src/lang_main/model_loader.py              |     2 +-
 src/lang_main/pipelines/predefined.py      |     9 +-
 tests/analysis/test_graphs.py              |     4 +-
 tests/analysis/test_timeline.py            |     4 +-
 tests/test_model_loader.py                 |    31 +-
 28 files changed, 17721 insertions(+), 17 deletions(-)
 create mode 100644 docs/lang_main/analysis/graphs.html
 create mode 100644 docs/lang_main/analysis/index.html
 create mode 100644 docs/lang_main/analysis/preprocessing.html
 create mode 100644 docs/lang_main/analysis/shared.html
 create mode 100644 docs/lang_main/analysis/timeline.html
 create mode 100644 docs/lang_main/analysis/tokens.html
 create mode 100644 docs/lang_main/config.html
 create mode 100644 docs/lang_main/constants.html
 create mode 100644 docs/lang_main/errors.html
 create mode 100644 docs/lang_main/index.html
 create mode 100644 docs/lang_main/io.html
 create mode 100644 docs/lang_main/loggers.html
 create mode 100644 docs/lang_main/model_loader.html
 create mode 100644 docs/lang_main/pipelines/base.html
 create mode 100644 docs/lang_main/pipelines/index.html
 create mode 100644 docs/lang_main/pipelines/predefined.html
 create mode 100644 docs/lang_main/render/cytoscape.html
 create mode 100644 docs/lang_main/render/cytoscape_monkeypatch.html
 create mode 100644 docs/lang_main/render/index.html
 create mode 100644 docs/lang_main/search.html
 create mode 100644 docs/lang_main/types.html

diff --git a/docs/lang_main/analysis/graphs.html b/docs/lang_main/analysis/graphs.html
new file mode 100644
index 0000000..6c54b7a
--- /dev/null
+++ b/docs/lang_main/analysis/graphs.html
@@ -0,0 +1,1844 @@

lang_main.analysis.graphs API documentation
+
+
+

Module lang_main.analysis.graphs

+
+
+
+
+
+
+
+
+

Functions

+
+
+def add_betweenness_centrality(graph: DiGraph | Graph,
edge_weight_property: str | None = None,
property_name: str = 'betweenness_centrality') ‑> None
+
+
+
+ +Expand source code + +
def add_betweenness_centrality(
+    graph: DiGraph | Graph,
+    edge_weight_property: str | None = None,
+    property_name: str = PROPERTY_NAME_BETWEENNESS_CENTRALITY,
+) -> None:
+    """adds the betweenness centrality as property to each node of the given graph
+    Operation is performed inplace.
+
+    Parameters
+    ----------
+    graph : DiGraph | Graph
+        Graph with betweenness centrality as node property added inplace
+    edge_weight_property : str | None, optional
+        property of the edges which contains the weight information,
+        not necessarily needed, by default None
+    property_name : str, optional
+        target name for the property containing the betweenness centrality in nodes,
+        by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
+    """
+
+    node_property_mapping = cast(
+        dict[str, float],
+        nx.betweenness_centrality(graph, normalized=True, weight=edge_weight_property),  # type: ignore
+    )
+    nx.set_node_attributes(
+        graph,
+        node_property_mapping,
+        name=property_name,
+    )
+
+

Adds the betweenness centrality as a property to each node of the given graph. +The operation is performed inplace.

+

Parameters

+
+
graph : DiGraph | Graph
+
Graph with betweenness centrality as node property added inplace
+
edge_weight_property : str | None, optional
+
property of the edges which contains the weight information, +not necessarily needed, by default None
+
property_name : str, optional
+
target name for the property containing the betweenness centrality in nodes, +by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
+
+
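Usage sketch (illustrative only, not part of the generated documentation), assuming the function is imported from lang_main.analysis.graphs and networkx is available; the node names are made up:

>>> import networkx as nx
>>> G = nx.DiGraph()
>>> G.add_edge('a', 'b', weight=2)
>>> G.add_edge('b', 'c', weight=1)
>>> add_betweenness_centrality(G, edge_weight_property='weight')
>>> 'betweenness_centrality' in G.nodes['b']
True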
+
+def add_importance_metric(graph: DiGraph | Graph,
property_name: str = 'importance',
property_name_weighted_degree: str = 'degree_weighted',
property_name_betweenness: str = 'betweenness_centrality') ‑> None
+
+
+
+ +Expand source code + +
def add_importance_metric(
+    graph: DiGraph | Graph,
+    property_name: str = PROPERTY_NAME_IMPORTANCE,
+    property_name_weighted_degree: str = PROPERTY_NAME_DEGREE_WEIGHTED,
+    property_name_betweenness: str = PROPERTY_NAME_BETWEENNESS_CENTRALITY,
+) -> None:
+    """Adds a custom importance metric as property to each node of the given graph.
+    Can be used to decide which nodes are of high importance and also to build node size
+    mappings.
+    Operation is performed inplace.
+
+    Parameters
+    ----------
+    graph : DiGraph | Graph
+        Graph with importance metric as node property added inplace
+    property_name : str, optional
+        target name for the property containing the importance metric in nodes,
+        by default PROPERTY_NAME_IMPORTANCE
+    property_name_weighted_degree : str, optional
+        name of the node property containing the weighted degree,
+        by default PROPERTY_NAME_DEGREE_WEIGHTED
+    property_name_betweenness : str, optional
+        name of the node property containing the betweenness centrality,
+        by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
+    """
+    # build mapping for importance metric
+    node_property_mapping: dict[str, float] = {}
+    for node in cast(Iterable[str], graph.nodes):
+        node_data = cast(dict[str, float], graph.nodes[node])
+
+        if property_name_weighted_degree not in node_data:
+            raise NodePropertyNotContainedError(
+                (
+                    f'Node data does not contain weighted degree '
+                    f'with name {property_name_weighted_degree}.'
+                )
+            )
+        elif property_name_betweenness not in node_data:
+            raise NodePropertyNotContainedError(
+                (
+                    f'Node data does not contain betweenness centrality '
+                    f'with name {property_name_betweenness}.'
+                )
+            )
+
+        prio = node_data[property_name_weighted_degree] * node_data[property_name_betweenness]
+        node_property_mapping[node] = prio
+
+    nx.set_node_attributes(
+        graph,
+        node_property_mapping,
+        name=property_name,
+    )
+
+

Adds a custom importance metric as property to each node of the given graph. +Can be used to decide which nodes are of high importance and also to build node size +mappings. +Operation is performed inplace.

+

Parameters

+
+
graph : DiGraph | Graph
+
Graph with importance metric as node property added inplace
+
property_name : str, optional
+
target name for the property containing the importance metric in nodes, +by default PROPERTY_NAME_IMPORTANCE
+
property_name_weighted_degree : str, optional
+
name of the node property containing the weighted degree, +by default PROPERTY_NAME_DEGREE_WEIGHTED
+
property_name_betweenness : str, optional
+
name of the node property containing the betweenness centrality, +by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
+
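Illustrative sketch of the intended call order (not part of the generated documentation): the weighted degree and the betweenness centrality must already be present as node properties, otherwise NodePropertyNotContainedError is raised. Assumes networkx and the other helpers from this module:

>>> import networkx as nx
>>> G = nx.DiGraph([('a', 'b'), ('b', 'c')])
>>> add_weighted_degree(G)
>>> add_betweenness_centrality(G)
>>> add_importance_metric(G)
>>> sorted(G.nodes['b'])
['betweenness_centrality', 'degree_weighted', 'importance']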
+
+
+def add_weighted_degree(graph: DiGraph | Graph,
edge_weight_property: str = 'weight',
property_name: str = 'degree_weighted') ‑> None
+
+
+
+ +Expand source code + +
def add_weighted_degree(
+    graph: DiGraph | Graph,
+    edge_weight_property: str = 'weight',
+    property_name: str = PROPERTY_NAME_DEGREE_WEIGHTED,
+) -> None:
+    """adds the weighted degree as property to each node of the given graph
+    Operation is performed inplace.
+
+    Parameters
+    ----------
+    graph : DiGraph | Graph
+        Graph with weighted degree as node property added inplace
+    edge_weight_property : str, optional
+        property of the edges which contains the weight information, by default 'weight'
+    property_name : str, optional
+        target name for the property containing the weighted degree in nodes,
+        by default PROPERTY_NAME_DEGREE_WEIGHTED
+    """
+    node_property_mapping = cast(
+        dict[str, float],
+        dict(graph.degree(weight=edge_weight_property)),  # type: ignore
+    )
+    nx.set_node_attributes(
+        graph,
+        node_property_mapping,
+        name=property_name,
+    )
+
+

Adds the weighted degree as a property to each node of the given graph. +The operation is performed inplace.

+

Parameters

+
+
graph : DiGraph | Graph
+
Graph with weighted degree as node property added inplace
+
edge_weight_property : str, optional
+
property of the edges which contains the weight information, by default 'weight'
+
property_name : str, optional
+
target name for the property containing the weighted degree in nodes, +by default PROPERTY_NAME_DEGREE_WEIGHTED
+
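Illustrative sketch (not from the source), assuming networkx is available; the weighted degree of a node is the sum of the weights of its incoming and outgoing edges:

>>> import networkx as nx
>>> G = nx.DiGraph()
>>> G.add_edge('pump', 'seal', weight=3)
>>> G.add_edge('seal', 'leak', weight=2)
>>> add_weighted_degree(G)
>>> G.nodes['seal']['degree_weighted']
5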
+
+
+def convert_graph_to_cytoscape(graph: Graph | DiGraph) ‑> tuple[list[lang_main.types.CytoscapeData], lang_main.types.WeightData] +
+
+
+ +Expand source code + +
def convert_graph_to_cytoscape(
+    graph: Graph | DiGraph,
+) -> tuple[list[CytoscapeData], WeightData]:
+    cyto_data: list[CytoscapeData] = []
+    # iterate over nodes
+    nodes = cast(Iterable[NodeTitle], graph.nodes)
+    for node in nodes:
+        node_data: CytoscapeData = {
+            'data': {
+                'id': node,
+                'label': node,
+            }
+        }
+        cyto_data.append(node_data)
+    # iterate over edges
+    weights: set[int] = set()
+
+    edges = cast(
+        Iterable[
+            tuple[
+                NodeTitle,
+                NodeTitle,
+                EdgeWeight,
+            ]
+        ],
+        graph.edges.data('weight', default=1),  # type: ignore
+    )
+    for source, target, weight in edges:
+        weights.add(weight)
+        edge_data: CytoscapeData = {
+            'data': {
+                'source': source,
+                'target': target,
+                'weight': weight,
+            }
+        }
+        cyto_data.append(edge_data)
+
+    # TODO: add internal behaviour (if edge added check for new min/max)
+    min_weight: int = 0
+    max_weight: int = 0
+    if weights:
+        min_weight = min(weights)
+        max_weight = max(weights)
+    weight_metadata: WeightData = {'min': min_weight, 'max': max_weight}
+
+    return cyto_data, weight_metadata
+
+
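A small illustrative example (node names are made up): nodes are emitted first as {'data': {'id': ..., 'label': ...}} entries, followed by edges carrying 'source', 'target' and 'weight'; the second return value summarises the observed weight range:

>>> import networkx as nx
>>> G = nx.DiGraph()
>>> G.add_edge('a', 'b', weight=4)
>>> elements, weight_meta = convert_graph_to_cytoscape(G)
>>> elements[0]
{'data': {'id': 'a', 'label': 'a'}}
>>> weight_meta
{'min': 4, 'max': 4}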
+
+
+def convert_graph_to_undirected(graph: DiGraph, logging: bool = False, cast_int: bool = False) ‑> networkx.classes.graph.Graph +
+
+
+ +Expand source code + +
def convert_graph_to_undirected(
+    graph: DiGraph,
+    logging: bool = LOGGING_DEFAULT_GRAPHS,
+    cast_int: bool = False,
+) -> Graph:
+    dtype = np.float32
+    if cast_int:
+        dtype = np.uint32
+    # get adjacency matrix
+    adj_mat = typing.cast(DataFrame, nx.to_pandas_adjacency(G=graph, dtype=dtype))
+    arr = typing.cast(npt.NDArray[np.float32 | np.uint32], adj_mat.to_numpy())
+    if not cast_int:
+        arr = arr * (10**EDGE_WEIGHT_DECIMALS)
+        arr = np.round(arr, decimals=0)
+        arr = arr.astype(np.uint32)
+    # build undirected array: adding edges of lower triangular matrix to upper one
+    arr_upper = np.triu(arr)
+    arr_lower = np.tril(arr)
+    arr_lower = np.rot90(np.fliplr(arr_lower))
+    arr_new = arr_upper + arr_lower
+    if not cast_int:
+        arr_new = (arr_new / 10**EDGE_WEIGHT_DECIMALS).astype(np.float32)
+        arr_new = np.round(arr_new, decimals=EDGE_WEIGHT_DECIMALS)
+    # assign new data and create graph
+    adj_mat.loc[:] = arr_new  # type: ignore
+    graph_undir = typing.cast(Graph, nx.from_pandas_adjacency(df=adj_mat))
+
+    # info about graph
+    if logging:
+        logger.info('Successfully converted graph to one with undirected edges.')
+    _ = get_graph_metadata(graph=graph_undir, logging=logging)
+
+    return graph_undir
+
+
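The conversion works on the adjacency matrix: the lower triangle is mirrored onto the upper one, so reciprocal directed edges are folded into a single undirected edge whose weight is the sum of both directions. An illustrative sketch (assuming cast_int=True so the weights stay integers):

>>> import networkx as nx
>>> G = nx.DiGraph()
>>> G.add_edge('a', 'b', weight=2)
>>> G.add_edge('b', 'a', weight=3)
>>> H = convert_graph_to_undirected(G, cast_int=True)
>>> int(H['a']['b']['weight'])
5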
+
+
+def filter_graph_by_edge_weight(graph: TokenGraph,
bound_lower: int | None,
bound_upper: int | None,
property: str = 'weight') ‑> TokenGraph
+
+
+
+ +Expand source code + +
def filter_graph_by_edge_weight(
+    graph: TokenGraph,
+    bound_lower: int | None,
+    bound_upper: int | None,
+    property: str = 'weight',
+) -> TokenGraph:
+    """filters all edges which are within the provided bounds
+    inclusive limits: bound_lower <= edge_weight <= bound_upper are retained
+
+    Parameters
+    ----------
+    bound_lower : int | None
+        lower bound for edge weights, edges with weight equal to this value are retained
+    bound_upper : int | None
+        upper bound for edge weights, edges with weight equal to this value are retained
+
+    Returns
+    -------
+    TokenGraph
+        a copy of the graph with filtered edges
+    """
+    original_graph_edges = copy.deepcopy(graph.edges)
+    filtered_graph = graph.copy()
+
+    if not any((bound_lower, bound_upper)):
+        logger.warning('No bounds provided, returning original graph.')
+        return filtered_graph
+
+    for edge in original_graph_edges:
+        weight = typing.cast(int, filtered_graph[edge[0]][edge[1]][property])
+        if bound_lower is not None and weight < bound_lower:
+            filtered_graph.remove_edge(edge[0], edge[1])
+        if bound_upper is not None and weight > bound_upper:
+            filtered_graph.remove_edge(edge[0], edge[1])
+
+    filtered_graph.to_undirected(inplace=True, logging=False)
+    filtered_graph.update_metadata(logging=False)
+
+    return filtered_graph
+
+

filters all edges which are within the provided bounds +inclusive limits: bound_lower <= edge_weight <= bound_upper are retained

+

Parameters

+
+
bound_lower : int | None
+
lower bound for edge weights, edges with weight equal to this value are retained
+
bound_upper : int | None
+
upper bound for edge weights, edges with weight equal to this value are retained
+
+

Returns

+
+
TokenGraph
+
a copy of the graph with filtered edges
+
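An illustrative sketch (assuming a TokenGraph built via update_graph from this module; names are made up); edges outside the inclusive bounds are removed from the returned copy:

>>> tg = TokenGraph(enable_logging=False)
>>> update_graph(tg, parent='pump', child='seal', weight_connection=5)
>>> update_graph(tg, parent='seal', child='leak', weight_connection=1)
>>> filtered = filter_graph_by_edge_weight(tg, bound_lower=2, bound_upper=None)
>>> list(filtered.edges)
[('pump', 'seal')]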
+
+
+def filter_graph_by_node_degree(graph: TokenGraph,
bound_lower: int | None,
bound_upper: int | None) ‑> TokenGraph
+
+
+
+ +Expand source code + +
def filter_graph_by_node_degree(
+    graph: TokenGraph,
+    bound_lower: int | None,
+    bound_upper: int | None,
+) -> TokenGraph:
+    """filters all nodes which are within the provided bounds by their degree,
+    inclusive limits: bound_lower <= node_degree <= bound_upper are retained
+
+    Parameters
+    ----------
+    bound_lower : int | None
+        lower bound for node degree, nodes with degree equal to this value are retained
+    bound_upper : int | None
+        upper bound for node degree, nodes with degree equal to this value are retained
+
+    Returns
+    -------
+    TokenGraph
+        a copy of the graph with filtered nodes
+    """
+    # filter nodes by degree
+    original_graph_nodes = copy.deepcopy(graph.nodes)
+    filtered_graph = graph.copy()
+    filtered_graph_degree = copy.deepcopy(filtered_graph.degree)
+
+    if not any([bound_lower, bound_upper]):
+        logger.warning('No bounds provided, returning original graph.')
+        return filtered_graph
+
+    for node in original_graph_nodes:
+        degree = cast(int, filtered_graph_degree[node])  # type: ignore
+        if bound_lower is not None and degree < bound_lower:
+            filtered_graph.remove_node(node)
+        if bound_upper is not None and degree > bound_upper:
+            filtered_graph.remove_node(node)
+
+    filtered_graph.to_undirected(inplace=True, logging=False)
+    filtered_graph.update_metadata(logging=False)
+
+    return filtered_graph
+
+

filters all nodes which are within the provided bounds by their degree, +inclusive limits: bound_lower <= node_degree <= bound_upper are retained

+

Parameters

+
+
bound_lower : int | None
+
lower bound for node degree, nodes with degree equal to this value are retained
+
bound_upper : int | None
+
upper bound for node degree, nodes with degree equal to this value are retained
+
+

Returns

+
+
TokenGraph
+
a copy of the graph with filtered nodes
+
+
+
+def filter_graph_by_number_edges(graph: TokenGraph,
limit: int | None,
property: str = 'weight',
descending: bool = True) ‑> TokenGraph
+
+
+
+ +Expand source code + +
def filter_graph_by_number_edges(
+    graph: TokenGraph,
+    limit: int | None,
+    property: str = 'weight',
+    descending: bool = True,
+) -> TokenGraph:
+    graph = graph.copy()
+    # edges
+    original = set(graph.edges(data=property))  # type: ignore
+    original_sorted = sorted(original, key=lambda tup: tup[2], reverse=descending)
+    if limit is not None:
+        chosen = set(original_sorted[:limit])
+    else:
+        chosen = set(original_sorted)
+    edges_to_drop = original.difference(chosen)
+    graph.remove_edges_from(edges_to_drop)
+
+    return graph
+
+
+
+
+def get_graph_metadata(graph: Graph | DiGraph, logging: bool = False) ‑> dict[str, float] +
+
+
+ +Expand source code + +
def get_graph_metadata(
+    graph: Graph | DiGraph,
+    logging: bool = LOGGING_DEFAULT_GRAPHS,
+) -> dict[str, float]:
+    # info about graph
+    graph_info: dict[str, float] = {}
+    # nodes and edges
+    num_nodes = len(graph.nodes)
+    num_edges = len(graph.edges)
+    # edge weights
+    min_edge_weight: int = 1_000_000
+    max_edge_weight: int = 0
+    for edge in graph.edges:
+        weight = typing.cast(int, graph[edge[0]][edge[1]]['weight'])
+        if weight < min_edge_weight:
+            min_edge_weight = weight
+        if weight > max_edge_weight:
+            max_edge_weight = weight
+
+    # memory
+    edge_mem = sum([sys.getsizeof(e) for e in graph.edges])
+    node_mem = sum([sys.getsizeof(n) for n in graph.nodes])
+    total_mem = edge_mem + node_mem
+
+    graph_info.update(
+        num_nodes=num_nodes,
+        num_edges=num_edges,
+        min_edge_weight=min_edge_weight,
+        max_edge_weight=max_edge_weight,
+        node_memory=node_mem,
+        edge_memory=edge_mem,
+        total_memory=total_mem,
+    )
+
+    if logging:
+        logger.info('Graph properties: %d Nodes, %d Edges', num_nodes, num_edges)
+        logger.info('Node memory: %.2f KB', (node_mem / 1024))
+        logger.info('Edge memory: %.2f KB', (edge_mem / 1024))
+        logger.info('Total memory: %.2f KB', (total_mem / 1024))
+
+    return graph_info
+
+
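A short illustrative call (assuming networkx is available); the returned dictionary bundles node/edge counts, the observed edge-weight range and a rough memory estimate:

>>> import networkx as nx
>>> G = nx.DiGraph()
>>> G.add_edge('a', 'b', weight=2)
>>> info = get_graph_metadata(G)
>>> (info['num_nodes'], info['num_edges'], info['max_edge_weight'])
(2, 1, 2)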
+
+
+def normalise_array_linear(array: npt.NDArray[np.float32]) ‑> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]] +
+
+
+ +Expand source code + +
def normalise_array_linear(
+    array: npt.NDArray[np.float32],
+) -> npt.NDArray[np.float32]:
+    """apply standard linear normalisation
+
+    Parameters
+    ----------
+    array : npt.NDArray[np.float32]
+        array which shall be normalised
+
+    Returns
+    -------
+    npt.NDArray[np.float32]
+        min/max normalised array
+    """
+    div = array.max() - array.min()
+    if div != 0:
+        arr_norm = (array - array.min()) / div
+        return arr_norm.astype(np.float32)
+    else:
+        return np.zeros(shape=array.shape, dtype=np.float32)
+
+

apply standard linear normalisation

+

Parameters

+
+
array : npt.NDArray[np.float32]
+
array which shall be normalised
+
+

Returns

+
+
npt.NDArray[np.float32]
+
min/max normalised array
+
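Worked example (illustrative): the minimum maps to 0, the maximum to 1, values in between linearly; a constant array yields all zeros because the divisor would otherwise be 0:

>>> import numpy as np
>>> normalise_array_linear(np.array([2.0, 4.0, 6.0], dtype=np.float32))
array([0. , 0.5, 1. ], dtype=float32)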
+
+
+def pipe_add_graph_metrics(*graphs: DiGraph | Graph) ‑> tuple[networkx.classes.digraph.DiGraph | networkx.classes.graph.Graph, ...] +
+
+
+ +Expand source code + +
def pipe_add_graph_metrics(
+    *graphs: DiGraph | Graph,
+) -> tuple[DiGraph | Graph, ...]:
+    collection: list[DiGraph | Graph] = []
+    for graph in graphs:
+        graph_copy = copy.deepcopy(graph)
+        add_weighted_degree(graph_copy)
+        add_betweenness_centrality(graph_copy)
+        add_importance_metric(graph_copy)
+        collection.append(graph_copy)
+
+    return tuple(collection)
+
+
+
+
+def pipe_rescale_graph_edge_weights(graph: TokenGraph) ‑> tuple[TokenGraph, networkx.classes.graph.Graph] +
+
+
+ +Expand source code + +
def pipe_rescale_graph_edge_weights(
+    graph: TokenGraph,
+) -> tuple[TokenGraph, Graph]:
+    """helper function to allow calls in pipelines
+
+    Parameters
+    ----------
+    graph : TokenGraph
+        token graph pushed through pipeline
+
+    Returns
+    -------
+    tuple[TokenGraph, Graph]
+        token graph (directed) and undirected version with rescaled edge weights
+    """
+    graph = graph.copy()
+
+    return graph.rescale_edge_weights()
+
+

helper function to allow calls in pipelines

+

Parameters

+
+
graph : TokenGraph
+
token graph pushed through pipeline
+
+

Returns

+
+
tuple[TokenGraph, Graph]
+
token graph (directed) and undirected version with rescaled edge weights
+
+
+
+def rescale_edge_weights(graph: Graph | DiGraph | TokenGraph,
weight_property: str = 'weight') ‑> networkx.classes.graph.Graph | networkx.classes.digraph.DiGraph | TokenGraph
+
+
+
+ +Expand source code + +
def rescale_edge_weights(
+    graph: Graph | DiGraph | TokenGraph,
+    weight_property: str = 'weight',
+) -> Graph | DiGraph | TokenGraph:
+    graph = graph.copy()
+    # check non-emptiness
+    verify_non_empty_graph(graph, including_edges=True)
+    # check if all edges contain weight property
+    verify_property(graph, property=weight_property)
+
+    weights = cast(list[int], [data['weight'] for data in graph.edges.values()])
+    w_log = cast(npt.NDArray[np.float32], np.log(weights, dtype=np.float32))
+    weights_norm = normalise_array_linear(w_log)
+    weights_adjusted = weight_scaling(weights_norm)
+    # assign new weight values
+    for idx, (node_1, node_2) in enumerate(graph.edges):
+        graph[node_1][node_2]['weight'] = weights_adjusted[idx]
+
+    return graph
+
+
+
+
+def save_to_GraphML(graph: DiGraph | Graph, saving_path: Path, filename: str | None = None) ‑> None +
+
+
+ +Expand source code + +
def save_to_GraphML(
+    graph: DiGraph | Graph,
+    saving_path: Path,
+    filename: str | None = None,
+) -> None:
+    if filename is not None:
+        saving_path = saving_path.joinpath(filename)
+    saving_path = saving_path.with_suffix('.graphml')
+    nx.write_graphml(G=graph, path=saving_path)
+    logger.info('Successfully saved graph as GraphML file under %s.', saving_path)
+
+
+
+
+def static_graph_analysis(graph: TokenGraph) ‑> tuple[TokenGraph] +
+
+
+ +Expand source code + +
def static_graph_analysis(
+    graph: TokenGraph,
+) -> tuple[TokenGraph]:
+    """helper function to allow the calculation of static metrics in pipelines
+
+    Parameters
+    ----------
+    graph : TokenGraph
+        token graph (directed)
+
+    Returns
+    -------
+    tuple[TokenGraph]
+        token graph (directed) with included undirected version and calculated KPIs
+    """
+    graph = graph.copy()
+    graph.perform_static_analysis()
+
+    return (graph,)
+
+

helper function to allow the calculation of static metrics in pipelines

+

Parameters

+
+
graph : TokenGraph
+
token graph (directed)
+
+

Returns

+
+
tuple[TokenGraph]
+
token graph (directed) with included undirected version and calculated KPIs
+
+
+
+def update_graph(graph: Graph | DiGraph,
*,
batch: Iterable[tuple[Hashable, Hashable]] | None = None,
parent: Hashable | None = None,
child: Hashable | None = None,
weight_connection: int | None = None) ‑> None
+
+
+
+ +Expand source code + +
def update_graph(
+    graph: Graph | DiGraph,
+    *,
+    batch: Iterable[tuple[Hashable, Hashable]] | None = None,
+    parent: Hashable | None = None,
+    child: Hashable | None = None,
+    weight_connection: int | None = None,
+) -> None:
+    if weight_connection is None:
+        weight_connection = 1
+    # check if edge not in Graph
+    if batch is not None:
+        graph.add_edges_from(batch, weight=weight_connection)
+    elif not graph.has_edge(parent, child):
+        # create new edge, nodes will be created if not already present
+        graph.add_edge(parent, child, weight=weight_connection)
+    else:
+        # update edge
+        graph[parent][child]['weight'] += weight_connection
+
+
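Illustrative sketch (assuming networkx is available): a missing edge is created with the given weight (default 1), an existing edge has its weight incremented, and a batch of edge tuples can be added in one call via the `batch` keyword:

>>> import networkx as nx
>>> G = nx.DiGraph()
>>> update_graph(G, parent='pump', child='seal')
>>> update_graph(G, parent='pump', child='seal', weight_connection=2)
>>> G['pump']['seal']['weight']
3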
+
+
+def verify_non_empty_graph(graph: DiGraph | Graph, including_edges: bool = True) ‑> None +
+
+
+ +Expand source code + +
def verify_non_empty_graph(
+    graph: DiGraph | Graph,
+    including_edges: bool = True,
+) -> None:
+    """check if the given graph is empty, presence of nodes is checked first,
+    then of edges
+
+    Parameters
+    ----------
+    graph : DiGraph | Graph
+        graph to check for emptiness
+    including_edges : bool, optional
+        whether to check for non-existence of edges, by default True
+
+    Raises
+    ------
+    EmptyGraphError
+        if graph does not contain any nodes and therefore edges
+    EmptyEdgesError
+        if graph does not contain any edges
+    """
+    if not tuple(graph.nodes):
+        raise EmptyGraphError(f'Graph object >>{graph}<< does not contain any nodes.')
+    elif including_edges and not tuple(graph.edges):
+        raise EmptyEdgesError(f'Graph object >>{graph}<< does not contain any edges.')
+
+

check if the given graph is empty, presence of nodes is checked first, +then of edges

+

Parameters

+
+
graph : DiGraph | Graph
+
graph to check for emptiness
+
including_edges : bool, optional
+
whether to check for non-existence of edges, by default True
+
+

Raises

+
+
EmptyGraphError
+
if graph does not contain any nodes and therefore edges
+
EmptyEdgesError
+
if graph does not contain any edges
+
+
+
+def verify_property(graph: Graph | DiGraph, property: str) ‑> None +
+
+
+ +Expand source code + +
def verify_property(
+    graph: Graph | DiGraph,
+    property: str,
+) -> None:
+    for node_1, node_2 in graph.edges:
+        if property not in graph[node_1][node_2]:
+            raise EdgePropertyNotContainedError(
+                (
+                    f'Edge property >>{property}<< not '
+                    f'available for edge >>({node_1}, {node_2})<<'
+                )
+            )
+
+
+
+
+def weight_scaling(weights: npt.NDArray[np.float32], a: float = 1.1, b: float = 0.05) ‑> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]] +
+
+
+ +Expand source code + +
def weight_scaling(
+    weights: npt.NDArray[np.float32],
+    a: float = 1.1,
+    b: float = 0.05,
+) -> npt.NDArray[np.float32]:
+    """non-linear scaling of already normalised edge weights [0;1]: bigger weights
+    have a smaller weight delta than smaller weights. Bigger values for parameter
+    `b` reinforce this effect.
+    Based on:
+    https://math.stackexchange.com/questions/4297805/exponential-function-that-passes-through-0-0-and-1-1-with-variable-slope
+
+    With default values the range of edge weights lies approximately between [0.1; 1]
+
+    Parameters
+    ----------
+    weights : npt.NDArray[np.float32]
+        pre-normalised edge weights as 1D array
+    a : float, optional
+        factor determining the mapped value of edge weights equal to 0
+        (approx. 0.1 with the defaults), by default 1.1
+    b : float, optional
+        adjust the curvature, smaller values increase it, by default 0.05
+
+    Returns
+    -------
+    npt.NDArray[np.float32]
+        non-linear adjusted edge weights as 1D array
+    """
+    adjusted_weights = (b**weights - a) / (b - a)
+
+    return np.round(adjusted_weights, decimals=EDGE_WEIGHT_DECIMALS)
+
+

non-linear scaling of already normalised edge weights [0;1]: bigger weights +have a smaller weight delta than smaller weights. Bigger values for parameter +b reinforce this effect. +Based on: +https://math.stackexchange.com/questions/4297805/exponential-function-that-passes-through-0-0-and-1-1-with-variable-slope

+

With default values the range of edge weights lies approximately between [0.1; 1]

+

Parameters

+
+
weights : npt.NDArray[np.float32]
+
pre-normalised edge weights as 1D array
+
a : float, optional
+
factor determining the mapped value of edge weights equal to 0 +(approx. 0.1 with the defaults), by default 1.1
+
b : float, optional
+
adjust the curvature, smaller values increase it, by default 0.05
+
+

Returns

+
+
npt.NDArray[np.float32]
+
non-linear adjusted edge weights as 1D array
+
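A worked illustration of the formula (b**w - a) / (b - a) with the defaults a = 1.1 and b = 0.05 (values approximate; the function additionally rounds to EDGE_WEIGHT_DECIMALS): w = 0 gives (1 - 1.1) / (0.05 - 1.1) ≈ 0.095, w = 0.5 gives (0.224 - 1.1) / (-1.05) ≈ 0.835, and w = 1 gives exactly 1.0.

>>> import numpy as np
>>> weight_scaling(np.array([0.0, 0.5, 1.0], dtype=np.float32))  # roughly [0.1, 0.83, 1.0], exact rounding depends on EDGE_WEIGHT_DECIMALS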
+
+
+
+
+

Classes

+
+
+class TokenGraph +(name: str = 'TokenGraph',
enable_logging: bool = True,
incoming_graph_data: Any | None = None,
**attr)
+
+
+
+ +Expand source code + +
class TokenGraph(DiGraph):
+    def __init__(
+        self,
+        name: str = 'TokenGraph',
+        enable_logging: bool = True,
+        incoming_graph_data: Any | None = None,
+        **attr,
+    ) -> None:
+        super().__init__(incoming_graph_data, **attr)
+        # logging of different actions
+        self.logging = enable_logging
+        # properties
+        self._name = name
+        # directed and undirected graph data
+        self._directed = self
+        self._metadata_directed: dict[str, float] = {}
+        self._undirected: Graph | None = None
+        self._metadata_undirected: dict[str, float] = {}
+        # indicate rescaled weights
+        self.rescaled_weights: bool = False
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+    def __str__(self) -> str:
+        return (
+            f'TokenGraph(name: {self.name}, number of nodes: '
+            f'{len(self.nodes)}, number of edges: '
+            f'{len(self.edges)})'
+        )
+
+    def disable_logging(self) -> None:
+        self.logging = False
+
+    # !! only used to verify that saving was done correctly
+    """
+    def __key(self) -> tuple[Hashable, ...]:
+        return (self.name, tuple(self.nodes), tuple(self.edges))
+    
+    def __hash__(self) -> int:
+        return hash(self.__key())
+    """
+
+    def copy(self) -> Self:
+        """returns a (deep) copy of the graph
+
+        Returns
+        -------
+        Self
+            deep copy of the graph
+        """
+        return copy.deepcopy(self)
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def directed(self) -> Self:
+        return self._directed
+
+    @property
+    def undirected(self) -> Graph:
+        if self._undirected is None:
+            self._undirected = self.to_undirected(inplace=False, logging=False)
+
+        return self._undirected
+
+    @property
+    def metadata_directed(self) -> dict[str, float]:
+        return self._metadata_directed
+
+    @property
+    def metadata_undirected(self) -> dict[str, float]:
+        return self._metadata_undirected
+
+    @overload
+    def to_undirected(
+        self,
+        inplace: Literal[True] = ...,
+        logging: bool | None = ...,
+    ) -> None: ...
+
+    @overload
+    def to_undirected(
+        self,
+        inplace: Literal[False],
+        logging: bool | None = ...,
+    ) -> Graph: ...
+
+    def to_undirected(
+        self,
+        inplace: bool = True,
+        logging: bool | None = None,
+    ) -> Graph | None:
+        if logging is None:
+            logging = self.logging
+        # cast to integer edge weights only if edges were not rescaled previously
+        cast_int: bool = True
+        if self.rescaled_weights:
+            cast_int = False
+
+        self._undirected = convert_graph_to_undirected(
+            graph=self,
+            logging=logging,
+            cast_int=cast_int,
+        )
+        self._metadata_undirected = get_graph_metadata(graph=self._undirected, logging=False)
+        if not inplace:
+            return self._undirected
+
+    def update_metadata(
+        self,
+        logging: bool | None = None,
+    ) -> None:
+        if logging is None:
+            logging = self.logging
+
+        self._metadata_directed = get_graph_metadata(graph=self, logging=logging)
+        if self._undirected is not None:
+            self._metadata_undirected = get_graph_metadata(
+                graph=self._undirected, logging=logging
+            )
+
+    def rescale_edge_weights(
+        self,
+    ) -> tuple[TokenGraph, Graph]:
+        """generate new instances of the directed and undirected TokenGraph with
+        rescaled edge weights
+        Only this method ensures that undirected graphs are scaled properly. If
+        the underlying `to_undirected` method of the directed and rescaled
+        TokenGraph instance is called the weights are not rescaled again. Thus,
+        the maximum edge weight can exceed the theoretical maximum value of 1. To
+        ensure consistent behaviour across different applications of the conversion to
+        undirected graphs new instances are returned, especially for the undirected
+        graph.
+        In contrast, the new directed TokenGraph contains an undirected version without
+        rescaling of the weights. Therefore, this undirected version differs from the version
+        returned by this method.
+
+        Returns
+        -------
+        tuple[TokenGraph, Graph]
+            directed and undirected instances
+        """
+        self.to_undirected(inplace=True, logging=False)
+        token_graph = rescale_edge_weights(self.directed)
+        token_graph.rescaled_weights = True
+        token_graph.update_metadata(logging=False)
+        undirected = rescale_edge_weights(self.undirected)
+
+        return token_graph, undirected
+
+    def perform_static_analysis(self) -> None:
+        """calculate different metrics directly on the data of the underlying graphs
+        (directed and undirected)
+
+        Current operations:
+            - adding weighted degree
+        """
+        add_weighted_degree(self)
+        add_weighted_degree(self.undirected)
+
+    def _save_prepare(
+        self,
+        path: Path,
+        filename: str | None = None,
+    ) -> Path:
+        if filename is not None:
+            saving_path = path.joinpath(f'{filename}')
+        else:
+            saving_path = path.joinpath(f'{self.name}')
+
+        return saving_path
+
+    def to_GraphML(
+        self,
+        path: Path,
+        filename: str | None = None,
+        directed: bool = False,
+    ) -> None:
+        """save one of the stored graphs to GraphML format on disk,
+
+        Parameters
+        ----------
+        path : Path
+            target path for saving the file
+        filename : str | None, optional
+            filename to be given, by default None
+        directed : bool, optional
+            indicator whether directed or undirected graph
+            should be exported, by default False (undirected)
+
+        Raises
+        ------
+        ValueError
+            undirected graph should be exported but is not available
+        """
+        saving_path = self._save_prepare(path=path, filename=filename)
+
+        if directed:
+            target_graph = self.directed
+        else:
+            target_graph = self.undirected
+
+        save_to_GraphML(graph=target_graph, saving_path=saving_path)
+
+    def to_pickle(
+        self,
+        path: Path,
+        filename: str | None = None,
+    ) -> None:
+        """save whole TokenGraph object as pickle file
+
+        Parameters
+        ----------
+        path : Path
+            target path for saving the file
+        filename : str | None, optional
+            filename to be given, by default None
+        """
+        saving_path = self._save_prepare(path=path, filename=filename)
+        saving_path = saving_path.with_suffix('.pkl')
+        save_pickle(obj=self, path=saving_path)
+
+    @classmethod
+    def from_file(
+        cls,
+        path: Path,
+        node_type_graphml: type = str,
+    ) -> Self:
+        # !! no validity checks for pickle files
+        # !! GraphML files not correct because not all properties
+        # !! are parsed correctly
+        # TODO REWORK
+        match path.suffix:
+            case '.graphml':
+                graph = typing.cast(Self, nx.read_graphml(path, node_type=node_type_graphml))
+                logger.info('Successfully loaded graph from GraphML file %s.', path)
+            case '.pkl' | '.pickle':
+                graph = typing.cast(Self, load_pickle(path))
+                logger.info('Successfully loaded graph from pickle file %s.', path)
+            case _:
+                raise ValueError('File format not supported.')
+
+        return graph
+
+

Base class for directed graphs.

+

A DiGraph stores nodes and edges with optional data, or attributes.

+

DiGraphs hold directed edges. +Self loops are allowed but multiple +(parallel) edges are not.

+

Nodes can be arbitrary (hashable) Python objects with optional +key/value attributes. By convention None is not used as a node.

+

Edges are represented as links between nodes with optional +key/value attributes.

+

Parameters

+
+
incoming_graph_data : input graph (optional, default: None)
+
Data to initialize graph. If None (default) an empty +graph is created. +The data can be any format that is supported +by the to_networkx_graph() function, currently including edge list, +dict of dicts, dict of lists, NetworkX graph, 2D NumPy array, SciPy +sparse matrix, or PyGraphviz graph.
+
attr : keyword arguments, optional (default= no attributes)
+
Attributes to add to graph as key=value pairs.
+
+

See Also

+

Graph +MultiGraph +MultiDiGraph

+

Examples

+

Create an empty graph structure (a "null graph") with no nodes and +no edges.

+
>>> G = nx.DiGraph()
+
+

G can be grown in several ways.

+

Nodes:

+

Add one node at a time:

+
>>> G.add_node(1)
+
+

Add the nodes from any container (a list, dict, set or +even the lines from a file or the nodes from another graph).

+
>>> G.add_nodes_from([2, 3])
+>>> G.add_nodes_from(range(100, 110))
+>>> H = nx.path_graph(10)
+>>> G.add_nodes_from(H)
+
+

In addition to strings and integers any hashable Python object +(except None) can represent a node, e.g. a customized node object, +or even another Graph.

+
>>> G.add_node(H)
+
+

Edges:

+

G can also be grown by adding edges.

+

Add one edge,

+
>>> G.add_edge(1, 2)
+
+

a list of edges,

+
>>> G.add_edges_from([(1, 2), (1, 3)])
+
+

or a collection of edges,

+
>>> G.add_edges_from(H.edges)
+
+

If some edges connect nodes not yet in the graph, the nodes +are added automatically. +There are no errors when adding +nodes or edges that already exist.

+

Attributes:

+

Each graph, node, and edge can hold key/value attribute pairs +in an associated attribute dictionary (the keys must be hashable). +By default these are empty, but can be added or changed using +add_edge, add_node or direct manipulation of the attribute +dictionaries named graph, node and edge respectively.

+
>>> G = nx.DiGraph(day="Friday")
+>>> G.graph
+{'day': 'Friday'}
+
+

Add node attributes using add_node(), add_nodes_from() or G.nodes

+
>>> G.add_node(1, time="5pm")
+>>> G.add_nodes_from([3], time="2pm")
+>>> G.nodes[1]
+{'time': '5pm'}
+>>> G.nodes[1]["room"] = 714
+>>> del G.nodes[1]["room"]  # remove attribute
+>>> list(G.nodes(data=True))
+[(1, {'time': '5pm'}), (3, {'time': '2pm'})]
+
+

Add edge attributes using add_edge(), add_edges_from(), subscript +notation, or G.edges.

+
>>> G.add_edge(1, 2, weight=4.7)
+>>> G.add_edges_from([(3, 4), (4, 5)], color="red")
+>>> G.add_edges_from([(1, 2, {"color": "blue"}), (2, 3, {"weight": 8})])
+>>> G[1][2]["weight"] = 4.7
+>>> G.edges[1, 2]["weight"] = 4
+
+

Warning: we protect the graph data structure by making G.edges[1, 2] a +read-only dict-like structure. However, you can assign to attributes +in e.g. G.edges[1, 2]. Thus, use 2 sets of brackets to add/change +data attributes: G.edges[1, 2]['weight'] = 4 +(For multigraphs: MG.edges[u, v, key][name] = value).

+

Shortcuts:

+

Many common graph features allow python syntax to speed reporting.

+
>>> 1 in G  # check if node in graph
+True
+>>> [n for n in G if n < 3]  # iterate through nodes
+[1, 2]
+>>> len(G)  # number of nodes in graph
+5
+
+

Often the best way to traverse all edges of a graph is via the neighbors. +The neighbors are reported as an adjacency-dict G.adj or G.adjacency()

+
>>> for n, nbrsdict in G.adjacency():
+...     for nbr, eattr in nbrsdict.items():
+...         if "weight" in eattr:
+...             # Do something useful with the edges
+...             pass
+
+

But the edges reporting object is often more convenient:

+
>>> for u, v, weight in G.edges(data="weight"):
+...     if weight is not None:
+...         # Do something useful with the edges
+...         pass
+
+

Reporting:

+

Simple graph information is obtained using object-attributes and methods. +Reporting usually provides views instead of containers to reduce memory +usage. The views update as the graph is updated similarly to dict-views. +The objects nodes, edges and adj provide access to data attributes +via lookup (e.g. nodes[n], edges[u, v], adj[u][v]) and iteration +(e.g. nodes.items(), nodes.data('color'), +nodes.data('color', default='blue') and similarly for edges) +Views exist for nodes, edges, neighbors()/adj and degree.

+

For details on these and other miscellaneous methods, see below.

+

Subclasses (Advanced):

+

The Graph class uses a dict-of-dict-of-dict data structure. +The outer dict (node_dict) holds adjacency information keyed by node. +The next dict (adjlist_dict) represents the adjacency information and holds +edge data keyed by neighbor. +The inner dict (edge_attr_dict) represents +the edge data and holds edge attribute values keyed by attribute names.

+

Each of these three dicts can be replaced in a subclass by a user defined +dict-like object. In general, the dict-like features should be +maintained but extra features can be added. To replace one of the +dicts create a new graph class by changing the class(!) variable +holding the factory for that dict-like structure. The variable names are +node_dict_factory, node_attr_dict_factory, adjlist_inner_dict_factory, +adjlist_outer_dict_factory, edge_attr_dict_factory and graph_attr_dict_factory.

+

node_dict_factory : function, (default: dict) +Factory function to be used to create the dict containing node +attributes, keyed by node id. +It should require no arguments and return a dict-like object

+

node_attr_dict_factory: function, (default: dict) +Factory function to be used to create the node attribute +dict which holds attribute values keyed by attribute name. +It should require no arguments and return a dict-like object

+

adjlist_outer_dict_factory : function, (default: dict) +Factory function to be used to create the outer-most dict +in the data structure that holds adjacency info keyed by node. +It should require no arguments and return a dict-like object.

+

adjlist_inner_dict_factory : function, optional (default: dict) +Factory function to be used to create the adjacency list +dict which holds edge data keyed by neighbor. +It should require no arguments and return a dict-like object

+

edge_attr_dict_factory : function, optional (default: dict) +Factory function to be used to create the edge attribute +dict which holds attribute values keyed by attribute name. +It should require no arguments and return a dict-like object.

+

graph_attr_dict_factory : function, (default: dict) +Factory function to be used to create the graph attribute +dict which holds attribute values keyed by attribute name. +It should require no arguments and return a dict-like object.

+

Typically, if your extension doesn't impact the data structure all +methods will inherited without issue except: to_directed/to_undirected. +By default these methods create a DiGraph/Graph class and you probably +want them to create your extension of a DiGraph/Graph. To facilitate +this we define two class variables that you can set in your subclass.

+

to_directed_class : callable, (default: DiGraph or MultiDiGraph) +Class to create a new graph structure in the to_directed method. +If None, a NetworkX class (DiGraph or MultiDiGraph) is used.

+

to_undirected_class : callable, (default: Graph or MultiGraph) +Class to create a new graph structure in the to_undirected method. +If None, a NetworkX class (Graph or MultiGraph) is used.

+

Subclassing Example

+

Create a low memory graph class that effectively disallows edge +attributes by using a single attribute dict for all edges. +This reduces the memory used, but you lose edge attributes.

+
>>> class ThinGraph(nx.Graph):
+...     all_edge_dict = {"weight": 1}
+...
+...     def single_edge_dict(self):
+...         return self.all_edge_dict
+...
+...     edge_attr_dict_factory = single_edge_dict
+>>> G = ThinGraph()
+>>> G.add_edge(2, 1)
+>>> G[2][1]
+{'weight': 1}
+>>> G.add_edge(2, 2)
+>>> G[2][1] is G[2][2]
+True
+
+

Initialize a graph with edges, name, or graph attributes.

+

Parameters

+
+
incoming_graph_data : input graph (optional, default: None)
+
Data to initialize graph. +If None (default) an empty +graph is created. +The data can be an edge list, or any +NetworkX graph object. +If the corresponding optional Python +packages are installed the data can also be a 2D NumPy array, a +SciPy sparse array, or a PyGraphviz graph.
+
attr : keyword arguments, optional (default= no attributes)
+
Attributes to add to graph as key=value pairs.
+
+

See Also

+

convert

+

Examples

+
>>> G = nx.Graph()  # or DiGraph, MultiGraph, MultiDiGraph, etc
+>>> G = nx.Graph(name="my graph")
+>>> e = [(1, 2), (2, 3), (3, 4)]  # list of edges
+>>> G = nx.Graph(e)
+
+

Arbitrary graph attribute pairs (key=value) may be assigned

+
>>> G = nx.Graph(e, day="Friday")
+>>> G.graph
+{'day': 'Friday'}
+
+

Ancestors

+
    +
  • networkx.classes.digraph.DiGraph
  • +
  • networkx.classes.graph.Graph
  • +
+

Static methods

+
+
+def from_file(path: Path, node_type_graphml: type = builtins.str) ‑> Self +
+
+
+
+
+

Instance variables

+
+
prop directed : Self
+
+
+ +Expand source code + +
@property
+def directed(self) -> Self:
+    return self._directed
+
+
+
+
prop metadata_directed : dict[str, float]
+
+
+ +Expand source code + +
@property
+def metadata_directed(self) -> dict[str, float]:
+    return self._metadata_directed
+
+
+
+
prop metadata_undirected : dict[str, float]
+
+
+ +Expand source code + +
@property
+def metadata_undirected(self) -> dict[str, float]:
+    return self._metadata_undirected
+
+
+
+
prop name : str
+
+
+ +Expand source code + +
@property
+def name(self) -> str:
+    return self._name
+
+

String identifier of the graph.

+

This graph attribute appears in the attribute dict G.graph +keyed by the string "name". as well as an attribute (technically +a property) G.name. This is entirely user controlled.

+
+
prop undirected : Graph
+
+
+ +Expand source code + +
@property
+def undirected(self) -> Graph:
+    if self._undirected is None:
+        self._undirected = self.to_undirected(inplace=False, logging=False)
+
+    return self._undirected
+
+
+
+
+

Methods

+
+
+def copy(self) ‑> Self +
+
+
+ +Expand source code + +
def copy(self) -> Self:
+    """returns a (deep) copy of the graph
+
+    Returns
+    -------
+    Self
+        deep copy of the graph
+    """
+    return copy.deepcopy(self)
+
+

returns a (deep) copy of the graph

+

Returns

+
+
Self
+
deep copy of the graph
+
+
+
+def disable_logging(self) ‑> None +
+
+
+ +Expand source code + +
def disable_logging(self) -> None:
+    self.logging = False
+
+
+
+
+def perform_static_analysis(self) ‑> None +
+
+
+ +Expand source code + +
def perform_static_analysis(self) -> None:
+    """calculate different metrics directly on the data of the underlying graphs
+    (directed and undirected)
+
+    Current operations:
+        - adding weighted degree
+    """
+    add_weighted_degree(self)
+    add_weighted_degree(self.undirected)
+
+

calculate different metrics directly on the data of the underlying graphs +(directed and undirected)

+

Current operations: +- adding weighted degree

+
+
+def rescale_edge_weights(self) ‑> tuple[TokenGraph, networkx.classes.graph.Graph] +
+
+
+ +Expand source code + +
def rescale_edge_weights(
+    self,
+) -> tuple[TokenGraph, Graph]:
+    """generate new instances of the directed and undirected TokenGraph with
+    rescaled edge weights
+    Only this method ensures that undirected graphs are scaled properly. If
+    the underlying `to_undirected` method of the directed and rescaled
+    TokenGraph instance is called the weights are not rescaled again. Thus,
+    the maximum edge weight can exceed the theoretical maximum value of 1. To
+    ensure consistent behaviour across different applications of the conversion to
+    undirected graphs new instances are returned, especially for the undirected
+    graph.
+    In contrast, the new directed TokenGraph contains an undirected version without
+    rescaling of the weights. Therefore, this undirected version differs from the version
+    returned by this method.
+
+    Returns
+    -------
+    tuple[TokenGraph, Graph]
+        directed and undirected instances
+    """
+    self.to_undirected(inplace=True, logging=False)
+    token_graph = rescale_edge_weights(self.directed)
+    token_graph.rescaled_weights = True
+    token_graph.update_metadata(logging=False)
+    undirected = rescale_edge_weights(self.undirected)
+
+    return token_graph, undirected
+
+

generate new instances of the directed and undirected TokenGraph with +rescaled edge weights +Only this method ensures that undirected graphs are scaled properly. If +the underlying to_undirected method of the directed and rescaled +TokenGraph instance is called the weights are not rescaled again. Thus, +the maximum edge weight can exceed the theoretical maximum value of 1. To +ensure consistent behaviour across different applications of the conversion to +undirected graphs new instances are returned, especially for the undirected +graph. +In contrast, the new directed TokenGraph contains an undirected version without +rescaling of the weights. Therefore, this undirected version differs from the version +returned by this method.

+

Returns

+
+
tuple[TokenGraph, Graph]
+
directed and undirected instances
+
+
+
+def to_GraphML(self, path: Path, filename: str | None = None, directed: bool = False) ‑> None +
+
+
+ +Expand source code + +
def to_GraphML(
+    self,
+    path: Path,
+    filename: str | None = None,
+    directed: bool = False,
+) -> None:
+    """save one of the stored graphs to GraphML format on disk,
+
+    Parameters
+    ----------
+    path : Path
+        target path for saving the file
+    filename : str | None, optional
+        filename to be given, by default None
+    directed : bool, optional
+        indicator whether directed or undirected graph
+        should be exported, by default False (undirected)
+
+    Raises
+    ------
+    ValueError
+        undirected graph should be exported but is not available
+    """
+    saving_path = self._save_prepare(path=path, filename=filename)
+
+    if directed:
+        target_graph = self.directed
+    else:
+        target_graph = self.undirected
+
+    save_to_GraphML(graph=target_graph, saving_path=saving_path)
+
+

Save one of the stored graphs in GraphML format on disk.

+

Parameters

+
+
path : Path
+
target path for saving the file
+
filename : str | None, optional
+
filename to be given, by default None
+
directed : bool, optional
+
indicator whether directed or undirected graph +should be exported, by default False (undirected)
+
+

Raises

+
+
ValueError
+
undirected graph should be exported but is not available
+
+
+
+def to_pickle(self, path: Path, filename: str | None = None) ‑> None +
+
+
+ +Expand source code + +
def to_pickle(
+    self,
+    path: Path,
+    filename: str | None = None,
+) -> None:
+    """save whole TokenGraph object as pickle file
+
+    Parameters
+    ----------
+    path : Path
+        target path for saving the file
+    filename : str | None, optional
+        filename to be given, by default None
+    """
+    saving_path = self._save_prepare(path=path, filename=filename)
+    saving_path = saving_path.with_suffix('.pkl')
+    save_pickle(obj=self, path=saving_path)
+
+

save whole TokenGraph object as pickle file

+

Parameters

+
+
path : Path
+
target path for saving the file
+
filename : str | None, optional
+
filename to be given, by default None
+
+
+
+def to_undirected(self, inplace: bool = True, logging: bool | None = None) ‑> networkx.classes.graph.Graph | None +
+
+
+ +Expand source code + +
def to_undirected(
+    self,
+    inplace: bool = True,
+    logging: bool | None = None,
+) -> Graph | None:
+    if logging is None:
+        logging = self.logging
+    # cast to integer edge weights only if edges were not rescaled previously
+    cast_int: bool = True
+    if self.rescaled_weights:
+        cast_int = False
+
+    self._undirected = convert_graph_to_undirected(
+        graph=self,
+        logging=logging,
+        cast_int=cast_int,
+    )
+    self._metadata_undirected = get_graph_metadata(graph=self._undirected, logging=False)
+    if not inplace:
+        return self._undirected
+
+

Returns an undirected representation of the digraph.

+

Parameters

+
+
reciprocal : bool (optional)
+
 
+
If True only keep edges that appear in both directions
+
in the original digraph.
+
as_view : bool (optional, default=False)
+
 
+
+

If True return an undirected view of the original directed graph.

+

Returns

+
+
G : Graph
+
An undirected graph with the same name and nodes and +with edge (u, v, data) if either (u, v, data) or (v, u, data) +is in the digraph. +If both edges exist in digraph and +their edge data is different, only one edge is created +with an arbitrary choice of which edge data to use. +You must check and correct for this manually if desired.
+
+

See Also

+

Graph, copy, add_edge, add_edges_from

+

Notes

+

If edges in both directions (u, v) and (v, u) exist in the +graph, attributes for the new undirected edge will be a combination of +the attributes of the directed edges. +The edge data is updated +in the (arbitrary) order that the edges are encountered. +For +more customized control of the edge attributes use add_edge().

+

This returns a "deepcopy" of the edge, node, and +graph attributes which attempts to completely copy +all of the data and references.

+

This is in contrast to the similar G=DiGraph(D) which returns a +shallow copy of the data.

+

See the Python copy module for more information on shallow +and deep copies, https://docs.python.org/3/library/copy.html.

+

Warning: If you have subclassed DiGraph to use dict-like objects +in the data structure, those changes do not transfer to the +Graph created by this method.

+

Examples

+
>>> G = nx.path_graph(2)  # or MultiGraph, etc
+>>> H = G.to_directed()
+>>> list(H.edges)
+[(0, 1), (1, 0)]
+>>> G2 = H.to_undirected()
+>>> list(G2.edges)
+[(0, 1)]
+
+
+
+def update_metadata(self, logging: bool | None = None) ‑> None +
+
+
+ +Expand source code + +
def update_metadata(
+    self,
+    logging: bool | None = None,
+) -> None:
+    if logging is None:
+        logging = self.logging
+
+    self._metadata_directed = get_graph_metadata(graph=self, logging=logging)
+    if self._undirected is not None:
+        self._metadata_undirected = get_graph_metadata(
+            graph=self._undirected, logging=logging
+        )
+
+
+
+
+
+
+
+
+ +
diff --git a/docs/lang_main/analysis/index.html b/docs/lang_main/analysis/index.html
new file mode 100644
index 0000000..c916f52
--- /dev/null
+++ b/docs/lang_main/analysis/index.html
@@ -0,0 +1,98 @@

lang_main.analysis API documentation
+ + +
diff --git a/docs/lang_main/analysis/preprocessing.html b/docs/lang_main/analysis/preprocessing.html
new file mode 100644
index 0000000..ebd5e22
--- /dev/null
+++ b/docs/lang_main/analysis/preprocessing.html
@@ -0,0 +1,451 @@

lang_main.analysis.preprocessing API documentation
+
+
+

Module lang_main.analysis.preprocessing

+
+
+
+
+
+
+
+
+

Functions

+
+
+def analyse_feature(data: DataFrame, target_feature: str) ‑> tuple[pandas.core.frame.DataFrame] +
+
+
+ +Expand source code + +
def analyse_feature(
+    data: DataFrame,
+    target_feature: str,
+) -> tuple[DataFrame]:
+    # feature columns
+    feature_entries = data[target_feature]
+    logger.info(
+        'Number of entries for feature >>%s<<: %d', target_feature, len(feature_entries)
+    )
+    # obtain unique entries
+    unique_feature_entries = feature_entries.unique()
+
+    # prepare result DataFrame
+    cols = ['batched_idxs', 'entry', 'len', 'num_occur', 'assoc_obj_ids', 'num_assoc_obj_ids']
+    result_df = pd.DataFrame(columns=cols)
+
+    for entry in tqdm(unique_feature_entries, mininterval=1.0):
+        len_entry = len(entry)
+        filt = data[target_feature] == entry
+        temp = data[filt]
+        batched_idxs = temp.index.to_numpy()
+        assoc_obj_ids = temp['ObjektID'].unique()
+        assoc_obj_ids = np.sort(assoc_obj_ids, kind='stable')
+        num_assoc_obj_ids = len(assoc_obj_ids)
+        num_dupl = filt.sum()
+
+        conc_df = pd.DataFrame(
+            data=[
+                [batched_idxs, entry, len_entry, num_dupl, assoc_obj_ids, num_assoc_obj_ids]
+            ],
+            columns=cols,
+        )
+
+        result_df = pd.concat([result_df, conc_df], ignore_index=True)
+
+    result_df = result_df.sort_values(
+        by=['num_occur', 'len'], ascending=[False, False]
+    ).copy()
+
+    return (result_df,)
+
+
+
+
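A minimal usage sketch (not part of the generated source above): it assumes a DataFrame that contains the target feature column plus an 'ObjektID' column, since the function groups duplicate entries and collects the associated object IDs; the sample data is purely illustrative.

>>> import pandas as pd
>>> from lang_main.analysis.preprocessing import analyse_feature
>>> df = pd.DataFrame(
...     {
...         'VorgangsBeschreibung': ['Pumpe defekt', 'Pumpe defekt', 'Filter getauscht'],
...         'ObjektID': [11, 12, 11],
...     }
... )
>>> (result,) = analyse_feature(df, target_feature='VorgangsBeschreibung')
>>> list(result.columns)
['batched_idxs', 'entry', 'len', 'num_occur', 'assoc_obj_ids', 'num_assoc_obj_ids']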
+def load_raw_data(path: Path,
date_cols: Collection[str] = ('VorgangsDatum', 'ErledigungsDatum', 'Arbeitsbeginn', 'ErstellungsDatum')) ‑> tuple[pandas.core.frame.DataFrame]
+
+
+
+ +Expand source code + +
def load_raw_data(
+    path: Path,
+    date_cols: Collection[str] = (
+        'VorgangsDatum',
+        'ErledigungsDatum',
+        'Arbeitsbeginn',
+        'ErstellungsDatum',
+    ),
+) -> tuple[DataFrame]:
+    """load IHM dataset with standard structure
+
+    Parameters
+    ----------
+    path : Path
+        path to dataset file, usually a CSV file
+    date_cols : Collection[str], optional
+        columns which contain dates and are parsed as such,
+        by default (
+            'VorgangsDatum',
+            'ErledigungsDatum',
+            'Arbeitsbeginn',
+            'ErstellungsDatum',
+        )
+
+    Returns
+    -------
+    DataFrame
+        raw dataset as DataFrame
+    """
+    # load dataset
+    date_cols = list(date_cols)
+    data = pd.read_csv(
+        filepath_or_buffer=path,
+        sep=';',
+        encoding='cp1252',
+        parse_dates=list(date_cols),
+        dayfirst=True,
+    )
+    logger.info('Loaded dataset successfully.')
+    logger.info(
+        (
+            f'Dataset properties: number of entries: {len(data)}, '
+            f'number of features {len(data.columns)}'
+        )
+    )
+    return (data,)
+
+

load IHM dataset with standard structure

+

Parameters

+
+
path : Path
+
path to dataset file, usually a CSV file
+
date_cols : Collection[str], optional
+
columns which contain dates and are parsed as such, +by default ( +'VorgangsDatum', +'ErledigungsDatum', +'Arbeitsbeginn', +'ErstellungsDatum', +)
+
+

Returns

+
+
DataFrame
+
raw dataset as DataFrame
+
+
+
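A hedged usage sketch: the path below is a placeholder, and the file is expected to be a semicolon-separated CSV in cp1252 encoding containing the default date columns listed above.

>>> from pathlib import Path
>>> from lang_main.analysis.preprocessing import load_raw_data
>>> (raw,) = load_raw_data(Path('data/ihm_export.csv'))  # placeholder path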
+def merge_similarity_duplicates(data: DataFrame, model: SentenceTransformer, cos_sim_threshold: float) ‑> tuple[pandas.core.frame.DataFrame] +
+
+
+ +Expand source code + +
def merge_similarity_duplicates(
+    data: DataFrame,
+    model: SentenceTransformer,
+    cos_sim_threshold: float,
+) -> tuple[DataFrame]:
+    logger.info('Start merging of similarity candidates...')
+
+    # data
+    merged_data = data.copy()
+    model_input = merged_data['entry']
+    candidates_idx = candidates_by_index(
+        data_model_input=model_input,
+        model=model,
+        cos_sim_threshold=cos_sim_threshold,
+    )
+    # graph of similar ids
+    similar_id_graph, _ = similar_index_connection_graph(candidates_idx)
+
+    for similar_id_group in similar_index_groups(similar_id_graph):
+        similar_id_group = list(similar_id_group)
+        similar_data = merged_data.loc[similar_id_group, :]
+        # keep first entry with max number occurrences, then number of
+        # associated objects, then length of entry
+        similar_data = similar_data.sort_values(
+            by=['num_occur', 'num_assoc_obj_ids', 'len'],
+            ascending=[False, False, False],
+        )
+        # merge information to first entry
+        data_idx = cast(PandasIndex, similar_data.index[0])
+        similar_data.at[data_idx, 'num_occur'] = similar_data['num_occur'].sum()
+        assoc_obj_ids = similar_data['assoc_obj_ids'].to_numpy()
+        assoc_obj_ids = np.concatenate(assoc_obj_ids)
+        assoc_obj_ids = np.unique(assoc_obj_ids)
+        similar_data.at[data_idx, 'assoc_obj_ids'] = assoc_obj_ids
+        similar_data.at[data_idx, 'num_assoc_obj_ids'] = len(assoc_obj_ids)
+        # remaining indices, should be removed
+        similar_id_group.remove(data_idx)
+        merged_similar_data = similar_data.drop(index=similar_id_group)
+        # update entry in main dataset, drop remaining entries
+        merged_data.update(merged_similar_data)
+        merged_data = merged_data.drop(index=similar_id_group)
+
+    logger.info('Similarity candidates merged successfully.')
+
+    return (merged_data,)
+
+
+
+
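A sketch of how this step might be chained after analyse_feature, assuming raw is the DataFrame loaded above and sentence_model is an already loaded SentenceTransformer (e.g. via lang_main.model_loader); the threshold value is illustrative.

>>> from lang_main.analysis.preprocessing import analyse_feature, merge_similarity_duplicates
>>> (feature_df,) = analyse_feature(raw, target_feature='VorgangsBeschreibung')
>>> (merged,) = merge_similarity_duplicates(
...     feature_df, model=sentence_model, cos_sim_threshold=0.8
... )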
+def numeric_pre_filter_feature(data: DataFrame, feature: str, bound_lower: int | None, bound_upper: int | None) ‑> tuple[pandas.core.frame.DataFrame] +
+
+
+ +Expand source code + +
def numeric_pre_filter_feature(
+    data: DataFrame,
+    feature: str,
+    bound_lower: int | None,
+    bound_upper: int | None,
+) -> tuple[DataFrame]:
+    """filter DataFrame for a given numerical feature regarding their bounds
+    bounds are inclusive: entries (bound_lower <= entry <= bound_upper) are retained
+
+    Parameters
+    ----------
+    data : DataFrame
+        DataFrame to filter
+    feature : str
+        feature name to filter
+    bound_lower : int | None
+        lower bound of values to retain
+    bound_upper : int | None
+        upper bound of values to retain
+
+    Returns
+    -------
+    tuple[DataFrame]
+        filtered DataFrame
+
+    Raises
+    ------
+    ValueError
+        if no bounds are provided, at least one bound must be set
+    """
+    if bound_lower is None and bound_upper is None:
+        raise ValueError('No bounds for filtering provided')
+
+    data = data.copy()
+    if bound_lower is None:
+        bound_lower = cast(int, data[feature].min())
+    if bound_upper is None:
+        bound_upper = cast(int, data[feature].max())
+
+    filter_lower = data[feature] >= bound_lower
+    filter_upper = data[feature] <= bound_upper
+    filter = filter_lower & filter_upper
+
+    data = data.loc[filter]
+
+    return (data,)
+
+

filter a DataFrame on a given numerical feature by lower and upper bounds; bounds are inclusive: entries with (bound_lower <= entry <= bound_upper) are retained

+

Parameters

+
+
data : DataFrame
+
DataFrame to filter
+
feature : str
+
feature name to filter
+
bound_lower : int | None
+
lower bound of values to retain
+
bound_upper : int | None
+
upper bound of values to retain
+
+

Returns

+
+
tuple[DataFrame]
+
filtered DataFrame
+
+

Raises

+
+
ValueError
+
if no bounds are provided, at least one bound must be set
+
+
+
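A hedged sketch filtering on the 'num_occur' column produced by analyse_feature (feature_df as in the sketch above); the bound values are illustrative.

>>> from lang_main.analysis.preprocessing import numeric_pre_filter_feature
>>> (frequent,) = numeric_pre_filter_feature(
...     feature_df, feature='num_occur', bound_lower=2, bound_upper=None
... )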
+def remove_NA(data: DataFrame, target_features: Collection[str] = ('VorgangsBeschreibung',)) ‑> tuple[pandas.core.frame.DataFrame] +
+
+
+ +Expand source code + +
def remove_NA(
+    data: DataFrame,
+    target_features: Collection[str] = ('VorgangsBeschreibung',),
+) -> tuple[DataFrame]:
+    """function to drop NA entries based on a subset of features to be analysed
+
+    Parameters
+    ----------
+    data : DataFrame
+        standard IHM dataset, perhaps pre-cleaned
+    target_features : Collection[str], optional
+        subset to analyse to define an NA entry, by default ('VorgangsBeschreibung',)
+
+    Returns
+    -------
+    DataFrame
+        dataset with removed NA entries for given subset of features
+    """
+    target_features = list(target_features)
+    wo_NA = data.dropna(axis=0, subset=target_features, ignore_index=True).copy()  # type: ignore
+    logger.info(
+        f'Removed NA entries for features >>{target_features}<< from dataset successfully.'
+    )
+
+    return (wo_NA,)
+
+

function to drop NA entries based on a subset of features to be analysed

+

Parameters

+
+
data : DataFrame
+
standard IHM dataset, perhaps pre-cleaned
+
target_features : Collection[str], optional
+
subset to analyse to define an NA entry, by default ('VorgangsBeschreibung',)
+
+

Returns

+
+
DataFrame
+
dataset with removed NA entries for given subset of features
+
+
+
+def remove_duplicates(data: DataFrame) ‑> tuple[pandas.core.frame.DataFrame] +
+
+
+ +Expand source code + +
def remove_duplicates(
+    data: DataFrame,
+) -> tuple[DataFrame]:
+    """removes duplicated entries over all features in the given dataset
+
+    Parameters
+    ----------
+    data : DataFrame
+        read data with standard structure
+
+    Returns
+    -------
+    DataFrame
+        dataset with removed duplicates over all features
+    """
+    # obtain info about duplicates over all features
+    duplicates_filt = data.duplicated()
+    logger.info(f'Number of duplicates over all features: {duplicates_filt.sum()}')
+    # drop duplicates
+    wo_duplicates = data.drop_duplicates(ignore_index=True)
+    duplicates_subset: list[str] = [
+        'VorgangsID',
+        'ObjektID',
+    ]
+    duplicates_subset_filt = wo_duplicates.duplicated(subset=duplicates_subset)
+    logger.info(
+        (
+            'Number of duplicates over subset '
+            f'>>{duplicates_subset}<<: {duplicates_subset_filt.sum()}'
+        )
+    )
+    wo_duplicates = wo_duplicates.drop_duplicates(
+        subset=duplicates_subset, ignore_index=True
+    ).copy()
+    logger.info('Removed all duplicates from dataset successfully.')
+    logger.info(
+        'New Dataset properties: number of entries: %d, number of features %d',
+        len(wo_duplicates),
+        len(wo_duplicates.columns),
+    )
+
+    return (wo_duplicates,)
+
+

removes duplicated entries over all features in the given dataset

+

Parameters

+
+
data : DataFrame
+
read data with standard structure
+
+

Returns

+
+
DataFrame
+
dataset with removed duplicates over all features
+
+
+
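Taken together, a typical cleaning pass over the raw export might look like the following sketch, with raw as loaded earlier and the default feature names used throughout this module:

>>> from lang_main.analysis.preprocessing import remove_NA, remove_duplicates
>>> (without_na,) = remove_NA(raw, target_features=('VorgangsBeschreibung',))
>>> (cleaned,) = remove_duplicates(without_na)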
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/analysis/shared.html b/docs/lang_main/analysis/shared.html new file mode 100644 index 0000000..6a1f95e --- /dev/null +++ b/docs/lang_main/analysis/shared.html @@ -0,0 +1,273 @@ + + + + + + +lang_main.analysis.shared API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.analysis.shared

+
+
+
+
+
+
+
+
+

Functions

+
+
+def candidates_by_index(data_model_input: pandas.core.series.Series,
model: sentence_transformers.SentenceTransformer.SentenceTransformer,
cos_sim_threshold: float = 0.5) ‑> Iterator[tuple[int | numpy.int64, int | numpy.int64]]
+
+
+
+ +Expand source code + +
def candidates_by_index(
+    data_model_input: Series,
+    model: SentenceTransformer,
+    cos_sim_threshold: float = 0.5,
+) -> Iterator[tuple[PandasIndex, PandasIndex]]:
+    """function to filter candidate indices based on cosine similarity
+    using SentenceTransformer model in batch mode,
+    feed data as Series to retain information about indices of entries and
+    access them later in the original dataset
+
+    Parameters
+    ----------
+    data_model_input : Series
+        containing indices and text entries to process
+    model : SentenceTransformer
+        necessary SentenceTransformer model to encode text entries
+    cos_sim_threshold : float, optional
+        threshold for cosine similarity to filter candidates, by default 0.5
+
+    Yields
+    ------
+    Iterator[tuple[PandasIndex, PandasIndex]]
+        tuple of index pairs which meet the cosine similarity threshold
+    """
+    # embeddings
+    batch = cast(list[str], data_model_input.to_list())
+    embds = cast(
+        Tensor,
+        model.encode(
+            batch,
+            convert_to_numpy=False,
+            convert_to_tensor=True,
+            show_progress_bar=False,
+        ),
+    )
+    # cosine similarity
+    cos_sim = cast(npt.NDArray, model.similarity(embds, embds).numpy())
+    np.fill_diagonal(cos_sim, 0.0)
+    cos_sim = np.triu(cos_sim)
+    cos_sim_idx = np.argwhere(cos_sim >= cos_sim_threshold)
+
+    for idx_array in cos_sim_idx:
+        idx_pair = cast(
+            tuple[np.int64, np.int64], tuple(data_model_input.index[idx] for idx in idx_array)
+        )
+        yield idx_pair
+
+

yields pairs of candidate indices whose entries are similar, based on cosine similarity computed with a SentenceTransformer model in batch mode; the data is fed as a Series so that the original dataset indices are retained and can be accessed later

+

Parameters

+
+
+
data_model_input : Series
+
containing indices and text entries to process
+
model : SentenceTransformer
+
necessary SentenceTransformer model to encode text entries
+
cos_sim_threshold : float, optional
+
threshold for cosine similarity to filter candidates, by default 0.5
+
+

Yields

+
+
Iterator[tuple[PandasIndex, PandasIndex]]
+
tuple of index pairs which meet the cosine similarity threshold
+
+
+
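A hedged sketch: a small Series of texts is compared with an already loaded SentenceTransformer (sentence_model is assumed to exist) and the similar index pairs are collected; texts and threshold are illustrative.

>>> import pandas as pd
>>> from lang_main.analysis.shared import candidates_by_index
>>> texts = pd.Series(
...     ['Pumpe defekt', 'Pumpe kaputt', 'Filter getauscht'], index=[10, 20, 30]
... )
>>> pairs = list(
...     candidates_by_index(texts, model=sentence_model, cos_sim_threshold=0.8)
... )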
+def clean_string_slim(string: str) ‑> str +
+
+
+ +Expand source code + +
def clean_string_slim(string: str) -> str:
+    """mapping function to clean single string entries in a series (feature-wise)
+    of the dataset, used to be applied element-wise for string features
+
+    Parameters
+    ----------
+    string : str
+        dataset entry feature
+
+    Returns
+    -------
+    str
+        cleaned entry
+    """
+    # remove special chars
+    # string = pattern_escape_newline.sub(' ', string)
+    string = pattern_escape_seq.sub(' ', string)
+    string = pattern_repeated_chars.sub('', string)
+    # string = pattern_dates.sub('', string)
+    # dates are used for context, should not be removed at this stage
+    string = pattern_whitespace.sub(' ', string)
+    # remove whitespaces at the beginning and the end
+    string = string.strip()
+
+    return string
+
+

mapping function to clean a single string entry of the dataset (feature-wise); intended to be applied element-wise to string features

+

Parameters

+
+
string : str
+
dataset entry feature
+
+

Returns

+
+
str
+
cleaned entry
+
+
+
+def entry_wise_cleansing(data: pandas.core.frame.DataFrame,
target_features: Collection[str],
cleansing_func: Callable[[str], str] = <function clean_string_slim>) ‑> tuple[pandas.core.frame.DataFrame]
+
+
+
+ +Expand source code + +
def entry_wise_cleansing(
+    data: DataFrame,
+    target_features: Collection[str],
+    cleansing_func: Callable[[str], str] = clean_string_slim,
+) -> tuple[DataFrame]:
+    # apply given cleansing function to target feature
+    target_features = list(target_features)
+    data[target_features] = data[target_features].map(cleansing_func)
+    logger.info(
+        ('Successfully applied entry-wise cleansing procedure >>%s<< for features >>%s<<'),
+        cleansing_func.__name__,
+        target_features,
+    )
+    return (data,)
+
+
+
+
+def similar_index_connection_graph(similar_idx_pairs: Iterable[tuple[int | numpy.int64, int | numpy.int64]]) ‑> tuple[networkx.classes.graph.Graph, dict[str, float]] +
+
+
+ +Expand source code + +
def similar_index_connection_graph(
+    similar_idx_pairs: Iterable[tuple[PandasIndex, PandasIndex]],
+) -> tuple[Graph, dict[str, float]]:
+    # build index graph to obtain graph of connected (similar) indices
+    # use this graph to get connected components (indices which belong together)
+    # retain semantic connection on whole dataset
+    similar_id_graph = nx.Graph()
+    # for idx1, idx2 in similar_idx_pairs:
+    #     # inplace operation, parent/child do not really exist in undirected graph
+    #     update_graph(graph=similar_id_graph, parent=idx1, child=idx2)
+    update_graph(graph=similar_id_graph, batch=similar_idx_pairs)
+
+    graph_info = get_graph_metadata(graph=similar_id_graph, logging=False)
+
+    return similar_id_graph, graph_info
+
+
+
+
+def similar_index_groups(similar_id_graph: networkx.classes.graph.Graph) ‑> Iterator[tuple[int | numpy.int64, ...]] +
+
+
+ +Expand source code + +
def similar_index_groups(
+    similar_id_graph: Graph,
+) -> Iterator[tuple[PandasIndex, ...]]:
+    # groups of connected indices
+    ids_groups = cast(Iterator[set[PandasIndex]], nx.connected_components(G=similar_id_graph))
+
+    for id_group in ids_groups:
+        yield tuple(id_group)
+
+
+
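The two helpers above are typically chained: the index pairs from candidates_by_index become an undirected graph whose connected components are the groups of mutually similar entries. A minimal sketch, continuing from the candidates_by_index example above:

>>> from lang_main.analysis.shared import (
...     similar_index_connection_graph,
...     similar_index_groups,
... )
>>> graph, info = similar_index_connection_graph(pairs)
>>> groups = list(similar_index_groups(graph))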
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/analysis/timeline.html b/docs/lang_main/analysis/timeline.html new file mode 100644 index 0000000..13e2c56 --- /dev/null +++ b/docs/lang_main/analysis/timeline.html @@ -0,0 +1,333 @@ + + + + + + +lang_main.analysis.timeline API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.analysis.timeline

+
+
+
+
+
+
+
+
+

Functions

+
+
+def calc_delta_to_next_failure(data: pandas.core.frame.DataFrame,
date_feature: str = 'ErstellungsDatum',
name_delta_feature: str = 'Zeitspanne bis zum nächsten Ereignis [Tage]',
convert_to_days: bool = True) ‑> pandas.core.frame.DataFrame
+
+
+
+ +Expand source code + +
def calc_delta_to_next_failure(
+    data: DataFrameTLFiltered,
+    date_feature: str = 'ErstellungsDatum',
+    name_delta_feature: str = NAME_DELTA_FEAT_TO_NEXT_FAILURE,
+    convert_to_days: bool = True,
+) -> DataFrameTLFiltered:
+    data = data.copy()
+    last_val = data[date_feature].iat[-1]
+    shifted = data[date_feature].shift(-1, fill_value=last_val)
+    data[name_delta_feature] = shifted - data[date_feature]
+    data = data.sort_values(by=name_delta_feature, ascending=False)
+
+    if convert_to_days:
+        data[name_delta_feature] = data[name_delta_feature].dt.days
+
+    return data
+
+
+
+
+def calc_delta_to_repair(data: pandas.core.frame.DataFrame,
date_feature_start: str = 'ErstellungsDatum',
date_feature_end: str = 'ErledigungsDatum',
name_delta_feature: str = 'Zeitspanne bis zur Behebung [Tage]',
convert_to_days: bool = True) ‑> tuple[pandas.core.frame.DataFrame]
+
+
+
+ +Expand source code + +
def calc_delta_to_repair(
+    data: DataFrame,
+    date_feature_start: str = 'ErstellungsDatum',
+    date_feature_end: str = 'ErledigungsDatum',
+    name_delta_feature: str = NAME_DELTA_FEAT_TO_REPAIR,
+    convert_to_days: bool = True,
+) -> tuple[DataFrame]:
+    logger.info('Calculating time differences between start and end of operations...')
+    data = data.copy()
+    data[name_delta_feature] = data[date_feature_end] - data[date_feature_start]
+
+    if convert_to_days:
+        data[name_delta_feature] = data[name_delta_feature].dt.days
+
+    logger.info('Calculation successful.')
+
+    return (data,)
+
+
+
+
+def cleanup_descriptions(data: pandas.core.frame.DataFrame,
properties: Collection[str] = ('VorgangsBeschreibung', 'ErledigungsBeschreibung')) ‑> tuple[pandas.core.frame.DataFrame]
+
+
+
+ +Expand source code + +
def cleanup_descriptions(
+    data: DataFrame,
+    properties: Collection[str] = (
+        'VorgangsBeschreibung',
+        'ErledigungsBeschreibung',
+    ),
+) -> tuple[DataFrame]:
+    logger.info('Cleaning necessary descriptions...')
+    data = data.copy()
+    features = list(properties)
+    data[features] = data[features].fillna('N.V.')
+    (data,) = entry_wise_cleansing(data, target_features=features)
+    logger.info('Cleansing successful.')
+
+    return (data.copy(),)
+
+
+
+
+def filter_activities_per_obj_id(data: pandas.core.frame.DataFrame,
activity_feature: str = 'VorgangsTypName',
relevant_activity_types: Iterable[str] = ('Reparaturauftrag (Portal)',),
feature_obj_id: str = 'ObjektID',
threshold_num_activities: int = 1) ‑> tuple[pandas.core.frame.DataFrame, pandas.core.series.Series]
+
+
+
+ +Expand source code + +
def filter_activities_per_obj_id(
+    data: DataFrame,
+    activity_feature: str = 'VorgangsTypName',
+    relevant_activity_types: Iterable[str] = ('Reparaturauftrag (Portal)',),
+    feature_obj_id: str = 'ObjektID',
+    threshold_num_activities: int = 1,
+) -> tuple[DataFrame, Series]:
+    data = data.copy()
+    # filter only relevant activities, count occurrences for each ObjectID
+    logger.info('Filtering activities per ObjectID...')
+    filt_rel_activities = data[activity_feature].isin(relevant_activity_types)
+    data_filter_activities = data.loc[filt_rel_activities].copy()
+    num_activities_per_obj_id = cast(
+        Series, data_filter_activities[feature_obj_id].value_counts(sort=True)
+    )
+    # filter for ObjectIDs with more than given number of activities
+    filt_below_thresh = num_activities_per_obj_id <= threshold_num_activities
+    # index of series contains ObjectIDs
+    obj_ids_below_thresh = num_activities_per_obj_id[filt_below_thresh].index
+    filt_entries_below_thresh = data_filter_activities[feature_obj_id].isin(
+        obj_ids_below_thresh
+    )
+
+    num_activities_per_obj_id = num_activities_per_obj_id.loc[~filt_below_thresh]
+    data_filter_activities = data_filter_activities.loc[~filt_entries_below_thresh]
+    logger.info('Activities per ObjectID filtered successfully.')
+
+    return data_filter_activities, num_activities_per_obj_id
+
+
+
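A hedged sketch of this filter on a cleaned IHM DataFrame (cleaned is assumed to come from the preprocessing steps); with the defaults shown above it keeps only repair orders and drops ObjectIDs with at most one such activity.

>>> from lang_main.analysis.timeline import filter_activities_per_obj_id
>>> filtered, counts_per_obj = filter_activities_per_obj_id(
...     cleaned, threshold_num_activities=1
... )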
+
+def filter_timeline_cands(data: pandas.core.frame.DataFrame,
cands: dict[int, tuple[tuple[int | numpy.int64, ...], ...]],
obj_id: int,
entry_idx: int,
sort_feature: str = 'ErstellungsDatum') ‑> pandas.core.frame.DataFrame
+
+
+
+ +Expand source code + +
def filter_timeline_cands(
+    data: DataFrame,
+    cands: TimelineCandidates,
+    obj_id: ObjectID,
+    entry_idx: int,
+    sort_feature: str = 'ErstellungsDatum',
+) -> DataFrameTLFiltered:
+    data = data.copy()
+    cands_for_obj_id = cands[obj_id]
+    cands_choice = cands_for_obj_id[entry_idx]
+    data = data.loc[list(cands_choice)].sort_values(
+        by=sort_feature,
+        ascending=True,
+    )
+
+    return data
+
+
+
+
+def generate_model_input(data: pandas.core.frame.DataFrame,
target_feature_name: str = 'nlp_model_input',
model_input_features: Iterable[str] = ('VorgangsTypName', 'VorgangsArtText', 'VorgangsBeschreibung')) ‑> tuple[pandas.core.frame.DataFrame]
+
+
+
+ +Expand source code + +
def generate_model_input(
+    data: DataFrame,
+    target_feature_name: str = 'nlp_model_input',
+    model_input_features: Iterable[str] = (
+        'VorgangsTypName',
+        'VorgangsArtText',
+        'VorgangsBeschreibung',
+    ),
+) -> tuple[DataFrame]:
+    logger.info('Generating concatenation of model input features...')
+    data = data.copy()
+    model_input_features = list(model_input_features)
+    input_features = data[model_input_features].fillna('').astype(str)
+    data[target_feature_name] = input_features.apply(
+        lambda x: ' - '.join(x),
+        axis=1,
+    )
+    logger.info('Model input generated successfully.')
+
+    return (data,)
+
+
+
+
+def get_timeline_candidates(data: pandas.core.frame.DataFrame,
num_activities_per_obj_id: pandas.core.series.Series,
*,
model: sentence_transformers.SentenceTransformer.SentenceTransformer,
cos_sim_threshold: float,
feature_obj_id: str = 'ObjektID',
feature_obj_text: str = 'HObjektText',
model_input_feature: str = 'nlp_model_input') ‑> tuple[dict[int, tuple[tuple[int | numpy.int64, ...], ...]], dict[int, str]]
+
+
+
+ +Expand source code + +
def get_timeline_candidates(
+    data: DataFrame,
+    num_activities_per_obj_id: Series,
+    *,
+    model: SentenceTransformer,
+    cos_sim_threshold: float,
+    feature_obj_id: str = 'ObjektID',
+    feature_obj_text: str = 'HObjektText',
+    model_input_feature: str = 'nlp_model_input',
+) -> tuple[TimelineCandidates, dict[ObjectID, str]]:
+    logger.info('Obtaining timeline candidates...')
+    candidates = _get_timeline_candidates_index(
+        data=data,
+        num_activities_per_obj_id=num_activities_per_obj_id,
+        model=model,
+        cos_sim_threshold=cos_sim_threshold,
+        feature_obj_id=feature_obj_id,
+        model_input_feature=model_input_feature,
+    )
+    tl_candidates = _transform_timeline_candidates(candidates)
+    logger.info('Timeline candidates obtained successfully.')
+    # text mapping to obtain object descriptors
+    logger.info('Mapping ObjectIDs to their respective text descriptor...')
+    map_obj_text = _map_obj_id_to_texts(
+        data=data,
+        feature_obj_id=feature_obj_id,
+        feature_obj_text=feature_obj_text,
+    )
+    logger.info('ObjectIDs successfully mapped to text descriptors.')
+
+    return tl_candidates, map_obj_text
+
+
+
+
+def remove_non_relevant_obj_ids(data: pandas.core.frame.DataFrame,
thresh_unique_feat_per_id: int,
*,
feature_uniqueness: str = 'HObjektText',
feature_obj_id: str = 'ObjektID') ‑> tuple[pandas.core.frame.DataFrame]
+
+
+
+ +Expand source code + +
def remove_non_relevant_obj_ids(
+    data: DataFrame,
+    thresh_unique_feat_per_id: int,
+    *,
+    feature_uniqueness: str = 'HObjektText',
+    feature_obj_id: str = 'ObjektID',
+) -> tuple[DataFrame]:
+    logger.info('Removing non-relevant ObjectIDs from dataset...')
+    data = data.copy()
+    ids_to_ignore = _non_relevant_obj_ids(
+        data=data,
+        thresh_unique_feat_per_id=thresh_unique_feat_per_id,
+        feature_uniqueness=feature_uniqueness,
+        feature_obj_id=feature_obj_id,
+    )
+    # only retain entries with ObjectIDs not in IDs to ignore
+    data = data.loc[~(data[feature_obj_id].isin(ids_to_ignore))]
+    logger.debug('Ignored ObjectIDs: %s', ids_to_ignore)
+    logger.info('Non-relevant ObjectIDs removed successfully.')
+
+    return (data,)
+
+
+
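Putting the pieces together, a hedged end-to-end sketch for one object's timeline, continuing from the filter sketch above; sentence_model and the similarity threshold are assumed to come from the caller's configuration:

>>> from lang_main.analysis.timeline import (
...     filter_timeline_cands,
...     generate_model_input,
...     get_timeline_candidates,
... )
>>> (with_input,) = generate_model_input(filtered)
>>> cands, obj_texts = get_timeline_candidates(
...     with_input,
...     counts_per_obj,
...     model=sentence_model,
...     cos_sim_threshold=0.75,
... )
>>> some_obj_id = next(iter(cands))
>>> timeline = filter_timeline_cands(with_input, cands, some_obj_id, entry_idx=0)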
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/analysis/tokens.html b/docs/lang_main/analysis/tokens.html new file mode 100644 index 0000000..4746df0 --- /dev/null +++ b/docs/lang_main/analysis/tokens.html @@ -0,0 +1,320 @@ + + + + + + +lang_main.analysis.tokens API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.analysis.tokens

+
+
+
+
+
+
+
+
+

Functions

+
+
+def add_doc_info_to_graph(graph: TokenGraph,
doc: spacy.tokens.doc.Doc,
weight: int | None) ‑> None
+
+
+
+ +Expand source code + +
def add_doc_info_to_graph(
+    graph: TokenGraph,
+    doc: SpacyDoc,
+    weight: int | None,
+) -> None:
+    # iterate over sentences
+    for sent in doc.sents:
+        # iterate over tokens in sentence
+        for token in sent:
+            # skip tokens which are not relevant
+            if not (token.pos_ in POS_OF_INTEREST or token.tag_ in TAG_OF_INTEREST):
+                continue
+            # skip tokens which are dates or times
+            if token.pos_ == 'NUM' and is_str_date(string=token.text):
+                continue
+
+            relevant_descendants = obtain_relevant_descendants(token=token)
+            # for non-AUX: add parent <--> descendant pair to graph
+            if token.pos_ not in POS_INDIRECT:
+                for descendant in relevant_descendants:
+                    # add descendant and parent to graph
+                    update_graph(
+                        graph=graph,
+                        parent=token.lemma_,
+                        child=descendant.lemma_,
+                        weight_connection=weight,
+                    )
+            else:
+                # if indirect POS, make connection between all associated words
+                combs = combinations(relevant_descendants, r=2)
+                for comb in combs:
+                    # !! parents and children do not really exist in this case,
+                    # !! but only one connection is made
+                    update_graph(
+                        graph=graph,
+                        parent=comb[0].lemma_,
+                        child=comb[1].lemma_,
+                        weight_connection=weight,
+                    )
+
+
+
+
+def build_token_graph(data: pandas.core.frame.DataFrame,
model: spacy.language.Language,
*,
target_feature: str = 'entry',
weights_feature: str | None = None,
batch_idx_feature: str | None = 'batched_idxs',
build_map: bool = True,
batch_size_model: int = 50,
logging_graph: bool = True) ‑> tuple[TokenGraph, dict[int | numpy.int64, spacy.tokens.doc.Doc] | None]
+
+
+
+ +Expand source code + +
def build_token_graph(
+    data: DataFrame,
+    model: SpacyModel,
+    *,
+    target_feature: str = 'entry',
+    weights_feature: str | None = None,
+    batch_idx_feature: str | None = 'batched_idxs',
+    build_map: bool = True,
+    batch_size_model: int = 50,
+    logging_graph: bool = True,
+) -> tuple[TokenGraph, dict[PandasIndex, SpacyDoc] | None]:
+    graph = TokenGraph(enable_logging=logging_graph)
+    model_input = cast(tuple[str], tuple(data[target_feature].to_list()))
+    if weights_feature is not None:
+        weights = cast(tuple[int], tuple(data[weights_feature].to_list()))
+    else:
+        weights = None
+
+    docs_mapping: dict[PandasIndex, SpacyDoc] | None
+    if build_map and batch_idx_feature is None:
+        raise ValueError('Can not build mapping if batched indices are unknown.')
+    elif build_map:
+        indices = cast(tuple[list[PandasIndex]], tuple(data[batch_idx_feature].to_list()))
+        docs_mapping = {}
+    else:
+        indices = None
+        docs_mapping = None
+
+    index: int = 0
+
+    for doc in tqdm(
+        model.pipe(model_input, batch_size=batch_size_model), total=len(model_input)
+    ):
+        weight: int | None = None
+        if weights is not None:
+            weight = weights[index]
+
+        add_doc_info_to_graph(
+            graph=graph,
+            doc=doc,
+            weight=weight,
+        )
+        # build map if option chosen
+        if indices is not None and docs_mapping is not None:
+            corresponding_indices = indices[index]
+            for idx in corresponding_indices:
+                docs_mapping[idx] = doc
+
+        index += 1
+
+    # metadata
+    graph.update_metadata()
+    # convert to undirected
+    graph.to_undirected(logging=False)
+    graph.perform_static_analysis()
+
+    return graph, docs_mapping
+
+
+
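A hedged sketch of building a token graph from the merged feature table produced by the preprocessing sketches above, with a spaCy pipeline loaded via lang_main.model_loader; the variable names are illustrative:

>>> from lang_main.analysis.tokens import build_token_graph
>>> token_graph, doc_map = build_token_graph(
...     merged,
...     spacy_model,
...     target_feature='entry',
...     weights_feature='num_occur',
...     batch_idx_feature='batched_idxs',
... )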
+
+def is_str_date(string: str, fuzzy: bool = False) ‑> bool +
+
+
+ +Expand source code + +
def is_str_date(
+    string: str,
+    fuzzy: bool = False,
+) -> bool:
+    """not stable function to test strings for dates, not 100 percent reliable
+
+    Parameters
+    ----------
+    string : str
+        string to check for dates
+    fuzzy : bool, optional
+        whether to use the fuzzy capability of dateutil.parser.parse, by default False
+
+    Returns
+    -------
+    bool
+        indicates whether date was found or not
+    """
+    try:
+        # check if string is a number
+    # bare numbers only qualify as date components if they have 2 or 4 digits
+        int(string)
+        if len(string) not in {2, 4}:
+            return False
+    except ValueError:
+        # not a number
+        pass
+
+    try:
+        parse(string, fuzzy=fuzzy, dayfirst=True, yearfirst=False)
+        return True
+    except ValueError:
+        date_found: bool = False
+        match = pattern_dates.search(string)
+        if match is None:
+            return date_found
+        date_found = any(match.groups())
+        return date_found
+
+

heuristic check whether a string contains a date; not 100 percent reliable

+

Parameters

+
+
string : str
+
string to check for dates
+
fuzzy : bool, optional
+
whether to use the fuzzy capability of dateutil.parser.parse, by default False
+
+

Returns

+
+
bool
+
indicates whether date was found or not
+
+
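Small illustrative checks; the exact outcome depends on the installed dateutil version and the date patterns compiled in this module, so treat the results as indicative rather than guaranteed:

>>> from lang_main.analysis.tokens import is_str_date
>>> is_str_date('12.03.2021')
True
>>> is_str_date('Pumpe')
False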
+
+def obtain_relevant_descendants(token: spacy.tokens.token.Token) ‑> Iterator[spacy.tokens.token.Token] +
+
+
+ +Expand source code + +
def obtain_relevant_descendants(
+    token: SpacyToken,
+) -> Iterator[SpacyToken]:
+    for descendant in token.subtree:
+        # subtrees contain the token itself
+        # if current element is token skip this element
+        if descendant == token:
+            continue
+
+        # if descendant is a date, skip it
+        if is_str_date(string=descendant.text):
+            continue
+
+        logger.debug(
+            'Token >>%s<<, POS >>%s<< | descendant >>%s<<, POS >>%s<<',
+            token,
+            token.pos_,
+            descendant,
+            descendant.pos_,
+        )
+
+        # eliminate cases of cross-references with verbs
+        if (token.pos_ == 'AUX' or token.pos_ == 'VERB') and (
+            descendant.pos_ == 'AUX' or descendant.pos_ == 'VERB'
+        ):
+            continue
+        # skip cases in which descendant is indirect POS with others than verbs
+        elif descendant.pos_ in POS_INDIRECT:
+            continue
+        # skip cases in which child has no relevant POS or TAG
+        elif not (descendant.pos_ in POS_OF_INTEREST or descendant.tag_ in TAG_OF_INTEREST):
+            continue
+
+        yield descendant
+
+        # TODO look at results and fine-tune function accordingly
+
+
+
+
+def pre_clean_word(string: str) ‑> str +
+
+
+ +Expand source code + +
def pre_clean_word(string: str) -> str:
+    pattern = r'[^A-Za-zäöüÄÖÜ]+'
+    string = re.sub(pattern, '', string)
+
+    return string
+
+
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/config.html b/docs/lang_main/config.html new file mode 100644 index 0000000..230f43f --- /dev/null +++ b/docs/lang_main/config.html @@ -0,0 +1,206 @@ + + + + + + +lang_main.config API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.config

+
+
+
+
+
+
+
+
+

Functions

+
+
+def get_config_paths(root_folder: Path, cfg_name: str, cyto_stylesheet_name: str) ‑> tuple[pathlib.Path, pathlib.Path] +
+
+
+ +Expand source code + +
def get_config_paths(
+    root_folder: Path,
+    cfg_name: str,
+    cyto_stylesheet_name: str,
+) -> tuple[Path, Path]:
+    cfg_path_internal = (root_folder / cfg_name).resolve()
+    cyto_stylesheet_path = (root_folder / cyto_stylesheet_name).resolve()
+
+    return cfg_path_internal, cyto_stylesheet_path
+
+
+
+
+def load_cfg(starting_path: Path,
glob_pattern: str,
stop_folder_name: str | None,
lookup_cwd: bool = False) ‑> dict[str, typing.Any]
+
+
+
+ +Expand source code + +
def load_cfg(
+    starting_path: Path,
+    glob_pattern: str,
+    stop_folder_name: str | None,
+    lookup_cwd: bool = False,
+) -> dict[str, Any]:
+    """Look for configuration file. Internal configs are not used any more because
+    the library behaviour is only guaranteed by external configurations.
+
+    Parameters
+    ----------
+    starting_path : Path
+        path to start for the lookup
+    glob_pattern : str
+        pattern of the config file naming scheme
+    stop_folder_name : str | None
+        folder name at which the lookup should stop, the parent folder
+        is also searched, e.g.
+        if starting_path is path/to/start/folder and stop_folder_name is 'to',
+        then path/ is also searched
+
+    Returns
+    -------
+    dict[str, Any]
+        loaded config file
+
+    Raises
+    ------
+    LangMainConfigNotFoundError
+        if no config file was found
+    """
+    cfg_path: Path | None = None
+    if lookup_cwd:
+        print('Looking for cfg file in CWD.', flush=True)
+        cfg_path = search_cwd(glob_pattern)
+
+    if cfg_path is None:
+        print(
+            (
+                f'Looking iteratively for config file. Start: {starting_path}, '
+                f'stop folder: {stop_folder_name}'
+            ),
+            flush=True,
+        )
+        cfg_path = search_iterative(
+            starting_path=starting_path,
+            glob_pattern=glob_pattern,
+            stop_folder_name=stop_folder_name,
+        )
+
+    if cfg_path is None:
+        raise LangMainConfigNotFoundError('Config file was not found.')
+
+    config = load_toml_config(path_to_toml=cfg_path)
+    print(f'Loaded config from: >>{cfg_path}<<')
+
+    return config.copy()
+
+

Look for configuration file. Internal configs are not used any more because +the library behaviour is only guaranteed by external configurations.

+

Parameters

+
+
starting_path : Path
+
path to start for the lookup
+
glob_pattern : str
+
pattern of the config file naming scheme
+
stop_folder_name : str | None
+
folder name at which the lookup should stop, the parent folder +is also searched, e.g. +if starting_path is path/to/start/folder and stop_folder_name is 'to', +then path/ is also searched
+
+

Returns

+
+
dict[str, Any]
+
loaded config file
+
+

Raises

+
+
LangMainConfigNotFoundError
+
if no config file was found
+
+
+
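A hedged usage sketch; the glob pattern and folder names below are placeholders chosen for illustration, not values prescribed by the library:

>>> from pathlib import Path
>>> from lang_main.config import load_cfg
>>> cfg = load_cfg(
...     starting_path=Path.cwd(),
...     glob_pattern='lang_main_config*.toml',
...     stop_folder_name=None,
...     lookup_cwd=True,
... )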
+def load_toml_config(path_to_toml: str | Path) ‑> dict[str, typing.Any] +
+
+
+ +Expand source code + +
def load_toml_config(
+    path_to_toml: str | Path,
+) -> dict[str, Any]:
+    with open(path_to_toml, 'rb') as f:
+        data = tomllib.load(f)
+    print('Loaded TOML config file successfully.', flush=True)
+
+    return data
+
+
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/constants.html b/docs/lang_main/constants.html new file mode 100644 index 0000000..79ede97 --- /dev/null +++ b/docs/lang_main/constants.html @@ -0,0 +1,66 @@ + + + + + + +lang_main.constants API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.constants

+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/errors.html b/docs/lang_main/errors.html new file mode 100644 index 0000000..a713ec5 --- /dev/null +++ b/docs/lang_main/errors.html @@ -0,0 +1,330 @@ + + + + + + +lang_main.errors API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.errors

+
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class DependencyMissingError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class DependencyMissingError(Exception):
+    """Error raised if needed dependency could not be found"""
+
+

Error raised if needed dependency could not be found

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class EdgePropertyNotContainedError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class EdgePropertyNotContainedError(Exception):
+    """Error raised if a needed edge property is not contained in graph edges"""
+
+

Error raised if a needed edge property is not contained in graph edges

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class EmptyEdgesError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class EmptyEdgesError(EmptyGraphError):
+    """Error raised if action should be performed on a graph's edges, but
+    it does not contain any"""
+
+

Error raised if action should be performed on a graph's edges, but +it does not contain any

+

Ancestors

+ +
+
+class EmptyGraphError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class EmptyGraphError(Exception):
+    """Error raised if an operation should be performed on the graph,
+    but it does not contain any nodes or edges"""
+
+

Error raised if an operation should be performed on the graph, +but it does not contain any nodes or edges

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+

Subclasses

+ +
+
+class GraphRenderError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class GraphRenderError(Exception):
+    """Error raised if a graph object can not be rendered"""
+
+

Error raised if a graph object can not be rendered

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class LangMainConfigNotFoundError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class LangMainConfigNotFoundError(Exception):
+    """Error raised if a config file could not be found successfully"""
+
+

Error raised if a config file could not be found successfully

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class LanguageModelNotFoundError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class LanguageModelNotFoundError(Exception):
+    """Error raised if a given language model could not be loaded successfully"""
+
+

Error raised if a given language model could not be loaded successfully

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class NoPerformableActionError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class NoPerformableActionError(Exception):
+    """Error describing that no action is available in the current pipeline"""
+
+

Error describing that no action is available in the current pipeline

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class NodePropertyNotContainedError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class NodePropertyNotContainedError(Exception):
+    """Error raised if a needed node property is not contained in graph edges"""
+
+

Error raised if a needed node property is not contained in graph nodes

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class OutputInPipelineContainerError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class OutputInPipelineContainerError(Exception):
+    """Error raised if an output was detected by one of the performed
+    actions in a PipelineContainer. Each action in a PipelineContainer is itself a
+    procedure which does not have any parameters or return values and should therefore not
+    return any values."""
+
+

Error raised if an output was detected by one of the performed +actions in a PipelineContainer. Each action in a PipelineContainer is itself a +procedure which does not have any parameters or return values and should therefore not +return any values.

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+class WrongActionTypeError +(*args, **kwargs) +
+
+
+ +Expand source code + +
class WrongActionTypeError(Exception):
+    """Error raised if added action type is not supported by corresponding pipeline"""
+
+

Error raised if added action type is not supported by corresponding pipeline

+

Ancestors

+
    +
  • builtins.Exception
  • +
  • builtins.BaseException
  • +
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/index.html b/docs/lang_main/index.html new file mode 100644 index 0000000..1688465 --- /dev/null +++ b/docs/lang_main/index.html @@ -0,0 +1,123 @@ + + + + + + +lang_main API documentation + + + + + + + + + + + +
+ + +
+ + + diff --git a/docs/lang_main/io.html b/docs/lang_main/io.html new file mode 100644 index 0000000..2a5c223 --- /dev/null +++ b/docs/lang_main/io.html @@ -0,0 +1,227 @@ + + + + + + +lang_main.io API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.io

+
+
+
+
+
+
+
+
+

Functions

+
+
+def create_saving_folder(saving_path_folder: str | pathlib.Path, overwrite_existing: bool = False) ‑> None +
+
+
+ +Expand source code + +
def create_saving_folder(
+    saving_path_folder: str | Path,
+    overwrite_existing: bool = False,
+) -> None:
+    # check for existence of given path
+    if isinstance(saving_path_folder, str):
+        saving_path_folder = Path(saving_path_folder)
+    if not saving_path_folder.exists():
+        saving_path_folder.mkdir(parents=True)
+    else:
+        if overwrite_existing:
+            # overwrite if desired (deletes whole path and re-creates it)
+            shutil.rmtree(saving_path_folder)
+            saving_path_folder.mkdir(parents=True)
+        else:
+            logger.info(
+                (
+                    'Path >>%s<< already exists and remained unchanged. If you want to '
+                    'overwrite this path, use parameter >>overwrite_existing<<.'
+                ),
+                saving_path_folder,
+            )
+
+
+
+
+def decode_from_base64_str(b64_str: str, encoding: str = 'utf-8') ‑> Any +
+
+
+ +Expand source code + +
def decode_from_base64_str(
+    b64_str: str,
+    encoding: str = 'utf-8',
+) -> Any:
+    b64_bytes = b64_str.encode(encoding=encoding)
+    decoded = base64.b64decode(b64_bytes)
+    return pickle.loads(decoded)
+
+
+
+
+def encode_file_to_base64_str(path: pathlib.Path, encoding: str = 'utf-8') ‑> str +
+
+
+ +Expand source code + +
def encode_file_to_base64_str(
+    path: Path,
+    encoding: str = 'utf-8',
+) -> str:
+    with open(path, 'rb') as file:
+        b64_bytes = base64.b64encode(file.read())
+    return b64_bytes.decode(encoding=encoding)
+
+
+
+
+def encode_to_base64_str(obj: Any, encoding: str = 'utf-8') ‑> str +
+
+
+ +Expand source code + +
def encode_to_base64_str(
+    obj: Any,
+    encoding: str = 'utf-8',
+) -> str:
+    serialised = pickle.dumps(obj, protocol=PICKLE_PROTOCOL_VERSION)
+    b64_bytes = base64.b64encode(serialised)
+    return b64_bytes.decode(encoding=encoding)
+
+
+
+
+def get_entry_point(saving_path: pathlib.Path,
filename: str,
file_ext: str = '.pkl',
check_existence: bool = True) ‑> pathlib.Path
+
+
+
+ +Expand source code + +
def get_entry_point(
+    saving_path: Path,
+    filename: str,
+    file_ext: str = '.pkl',
+    check_existence: bool = True,
+) -> Path:
+    entry_point_path = (saving_path / filename).with_suffix(file_ext)
+    if check_existence and not entry_point_path.exists():
+        raise FileNotFoundError(
+            f'Could not find provided entry data under path: >>{entry_point_path}<<'
+        )
+
+    return entry_point_path
+
+
+
+
+def load_pickle(path: str | pathlib.Path) ‑> Any +
+
+
+ +Expand source code + +
def load_pickle(
+    path: str | Path,
+) -> Any:
+    with open(path, 'rb') as file:
+        obj = pickle.load(file)
+    logger.info('Loaded file successfully.')
+    return obj
+
+
+
+
+def save_pickle(obj: Any, path: str | pathlib.Path) ‑> None +
+
+
+ +Expand source code + +
def save_pickle(
+    obj: Any,
+    path: str | Path,
+) -> None:
+    with open(path, 'wb') as file:
+        pickle.dump(obj, file, protocol=PICKLE_PROTOCOL_VERSION)
+    logger.info('Saved file successfully under %s', path)
+
+
+
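A hedged round-trip sketch combining the helpers above; the folder and file names are placeholders:

>>> from pathlib import Path
>>> from lang_main.io import create_saving_folder, get_entry_point, load_pickle, save_pickle
>>> out_dir = Path('results')
>>> create_saving_folder(out_dir)
>>> target = get_entry_point(out_dir, 'preprocessed', check_existence=False)
>>> save_pickle({'answer': 42}, target)
>>> restored = load_pickle(target)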
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/loggers.html b/docs/lang_main/loggers.html new file mode 100644 index 0000000..0152ee3 --- /dev/null +++ b/docs/lang_main/loggers.html @@ -0,0 +1,66 @@ + + + + + + +lang_main.loggers API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.loggers

+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/model_loader.html b/docs/lang_main/model_loader.html new file mode 100644 index 0000000..c354e5e --- /dev/null +++ b/docs/lang_main/model_loader.html @@ -0,0 +1,162 @@ + + + + + + +lang_main.model_loader API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.model_loader

+
+
+
+
+
+
+
+
+

Functions

+
+
+def instantiate_model(model_load_map: ModelLoaderMap, model: LanguageModels) ‑> sentence_transformers.SentenceTransformer.SentenceTransformer | spacy.language.Language +
+
+
+ +Expand source code + +
def instantiate_model(
+    model_load_map: ModelLoaderMap,
+    model: LanguageModels,
+) -> Model:
+    if model not in model_load_map:
+        raise KeyError(f'Model >>{model}<< not known. Choose from: {model_load_map.keys()}')
+    builder_func = model_load_map[model]['func']
+    func_kwargs = model_load_map[model]['kwargs']
+
+    return builder_func(**func_kwargs)
+
+
+
+
+def load_sentence_transformer(model_name: STFRModelTypes | str,
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
backend: STFRBackends = torch,
device: STFRDeviceTypes = cpu,
local_files_only: bool = True,
trust_remote_code: bool = False,
model_save_folder: str | None = None,
model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
force_download: bool = False) ‑> sentence_transformers.SentenceTransformer.SentenceTransformer
+
+
+
+ +Expand source code + +
def load_sentence_transformer(
+    model_name: STFRModelTypes | str,
+    similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
+    backend: STFRBackends = STFRBackends.TORCH,
+    device: STFRDeviceTypes = STFRDeviceTypes.CPU,
+    local_files_only: bool = True,
+    trust_remote_code: bool = False,
+    model_save_folder: str | None = None,
+    model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
+    force_download: bool = False,
+) -> SentenceTransformer:
+    model_name_or_path = _preprocess_STFR_model_name(
+        model_name=model_name, backend=backend, force_download=force_download
+    )
+    model = SentenceTransformer(
+        model_name_or_path=model_name_or_path,
+        similarity_fn_name=similarity_func,
+        backend=backend,  # type: ignore Literal matches Enum
+        device=device,
+        cache_folder=model_save_folder,
+        local_files_only=local_files_only,
+        trust_remote_code=trust_remote_code,
+        model_kwargs=model_kwargs,  # type: ignore
+    )
+    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
+
+    return model
+
+
+
+
+def load_spacy(model_name: str) ‑> spacy.language.Language +
+
+
+ +Expand source code + +
def load_spacy(
+    model_name: str,
+) -> SpacyModel:
+    try:
+        spacy_model_obj = importlib.import_module(model_name)
+    except ModuleNotFoundError:
+        raise LanguageModelNotFoundError(
+            (
+                f'Could not find spaCy model >>{model_name}<<. '
+                f'Check if it is installed correctly.'
+            )
+        )
+    pretrained_model = cast(SpacyModel, spacy_model_obj.load())
+    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
+
+    return pretrained_model
+
+
+
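A hedged sketch of loading the two model types directly; the model names are examples of publicly available models, not defaults shipped with the library, and the spaCy model must be installed as a package for load_spacy to import it:

>>> from lang_main.model_loader import load_sentence_transformer, load_spacy
>>> sentence_model = load_sentence_transformer('all-MiniLM-L6-v2', local_files_only=False)
>>> spacy_model = load_spacy('de_core_news_sm')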
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/pipelines/base.html b/docs/lang_main/pipelines/base.html new file mode 100644 index 0000000..a95bb8f --- /dev/null +++ b/docs/lang_main/pipelines/base.html @@ -0,0 +1,755 @@ + + + + + + +lang_main.pipelines.base API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.pipelines.base

+
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class BasePipeline +(name: str, working_dir: Path) +
+
+
+ +Expand source code + +
class BasePipeline(ABC):
+    def __init__(
+        self,
+        name: str,
+        working_dir: Path,
+    ) -> None:
+        # init base class
+        super().__init__()
+
+        # name of pipeline
+        self.name = name
+        # working directory for pipeline == output path
+        self.working_dir = working_dir
+
+        # container for actions to perform during pass
+        self.actions: list[Callable] = []
+        self.action_names: list[str] = []
+        self.action_skip: list[bool] = []
+        # progress tracking, start at 1
+        self.curr_proc_idx: int = 1
+
+    def __repr__(self) -> str:
+        return (
+            f'{self.__class__.__name__}(name: {self.name}, '
+            f'working dir: {self.working_dir}, contents: {self.action_names})'
+        )
+
+    def panic_wrong_action_type(
+        self,
+        action: Any,
+        compatible_type: str,
+    ) -> Never:
+        raise WrongActionTypeError(
+            (
+                f'Action must be of type {compatible_type}, '
+                f'but is of type >>{type(action)}<<.'
+            )
+        )
+
+    def prep_run(self) -> None:
+        logger.info('Starting pipeline >>%s<<...', self.name)
+        # progress tracking
+        self.curr_proc_idx = 1
+        # check if performable actions available
+        if len(self.actions) == 0:
+            raise NoPerformableActionError(
+                'The pipeline does not contain any performable actions.'
+            )
+
+    def post_run(self) -> None:
+        logger.info(
+            'Processing pipeline >>%s<< successfully ended after %d steps.',
+            self.name,
+            (self.curr_proc_idx - 1),
+        )
+
+    @abstractmethod
+    def add(self) -> None: ...
+
+    @abstractmethod
+    def logic(self) -> None: ...
+
+    def run(self, *args, **kwargs) -> Any:
+        self.prep_run()
+        ret = self.logic(*args, **kwargs)
+        self.post_run()
+        return ret
+
+

Helper class that provides a standard way to create an ABC using +inheritance.

+

Ancestors

+
    +
  • abc.ABC
  • +
+

Subclasses

+ +

Methods

+
+
+def add(self) ‑> None +
+
+
+ +Expand source code + +
@abstractmethod
+def add(self) -> None: ...
+
+
+
+
+def logic(self) ‑> None +
+
+
+ +Expand source code + +
@abstractmethod
+def logic(self) -> None: ...
+
+
+
+
+def panic_wrong_action_type(self, action: Any, compatible_type: str) ‑> Never +
+
+
+ +Expand source code + +
def panic_wrong_action_type(
+    self,
+    action: Any,
+    compatible_type: str,
+) -> Never:
+    raise WrongActionTypeError(
+        (
+            f'Action must be of type {compatible_type}, '
+            f'but is of type >>{type(action)}<<.'
+        )
+    )
+
+
+
+
+def post_run(self) ‑> None +
+
+
+ +Expand source code + +
def post_run(self) -> None:
+    logger.info(
+        'Processing pipeline >>%s<< successfully ended after %d steps.',
+        self.name,
+        (self.curr_proc_idx - 1),
+    )
+
+
+
+
+def prep_run(self) ‑> None +
+
+
+ +Expand source code + +
def prep_run(self) -> None:
+    logger.info('Starting pipeline >>%s<<...', self.name)
+    # progress tracking
+    self.curr_proc_idx = 1
+    # check if performable actions available
+    if len(self.actions) == 0:
+        raise NoPerformableActionError(
+            'The pipeline does not contain any performable actions.'
+        )
+
+
+
+
+def run(self, *args, **kwargs) ‑> Any +
+
+
+ +Expand source code + +
def run(self, *args, **kwargs) -> Any:
+    self.prep_run()
+    ret = self.logic(*args, **kwargs)
+    self.post_run()
+    return ret
+
+
+
+
+
+
+class Pipeline +(name: str, working_dir: Path) +
+
+
+ +Expand source code + +
class Pipeline(BasePipeline):
+    def __init__(
+        self,
+        name: str,
+        working_dir: Path,
+    ) -> None:
+        # init base class
+        super().__init__(name=name, working_dir=working_dir)
+        # name of pipeline
+        self.name = name
+        # working directory for pipeline == output path
+        self.working_dir = working_dir
+        # container for actions to perform during pass
+        self.actions_kwargs: list[dict[str, Any]] = []
+        self.save_results: ResultHandling = []
+        self.load_results: ResultHandling = []
+        # intermediate result
+        self._intermediate_result: tuple[Any, ...] | None = None
+
+    def __repr__(self) -> str:
+        return (
+            f'{self.__class__.__name__}(name: {self.name}, '
+            f'working dir: {self.working_dir}, contents: {self.action_names})'
+        )
+
+    @override
+    def add(
+        self,
+        action: Callable,
+        action_kwargs: dict[str, Any] | None = None,
+        skip: bool = False,
+        save_result: bool = False,
+        load_result: bool = False,
+        filename: str | None = None,
+    ) -> None:
+        # check explicitly for function type
+        # if isinstance(action, FunctionType):
+        if action_kwargs is None:
+            action_kwargs = {}
+
+        if isinstance(action, Callable):
+            self.actions.append(action)
+            self.action_names.append(action.__name__)
+            self.actions_kwargs.append(action_kwargs.copy())
+            self.action_skip.append(skip)
+            self.save_results.append((save_result, filename))
+            self.load_results.append((load_result, filename))
+        else:
+            self.panic_wrong_action_type(action=action, compatible_type=Callable.__name__)
+
+    def get_result_path(
+        self,
+        action_idx: int,
+        filename: str | None,
+    ) -> tuple[Path, str]:
+        action_name = self.action_names[action_idx]
+        if filename is None:
+            target_filename = f'Pipe-{self.name}_Step-{self.curr_proc_idx}_{action_name}'
+        else:
+            target_filename = filename
+        target_path = self.working_dir.joinpath(target_filename).with_suffix('.pkl')
+        return target_path, action_name
+
+    def load_step(
+        self,
+        action_idx: int,
+        filename: str | None,
+    ) -> tuple[Any, ...]:
+        target_path, action_name = self.get_result_path(action_idx, filename)
+
+        if not target_path.exists():
+            raise FileNotFoundError(
+                (
+                    f'No intermediate results for action >>{action_name}<< '
+                    f'under >>{target_path}<< found'
+                )
+            )
+        # results should be tuple, but that is not guaranteed
+        result_loaded = cast(tuple[Any, ...], load_pickle(target_path))
+        if not isinstance(result_loaded, tuple):
+            raise TypeError(f'Loaded results must be tuple, not {type(result_loaded)}')
+
+        return result_loaded
+
+    def save_step(
+        self,
+        action_idx: int,
+        filename: str | None,
+    ) -> None:
+        target_path, _ = self.get_result_path(action_idx, filename)
+        save_pickle(obj=self._intermediate_result, path=target_path)
+
+    @override
+    def logic(
+        self,
+        starting_values: tuple[Any, ...] | None = None,
+    ) -> tuple[Any, ...]:
+        first_performed: bool = False
+
+        for idx, (action, action_kwargs) in enumerate(zip(self.actions, self.actions_kwargs)):
+            if self.action_skip[idx]:
+                self.curr_proc_idx += 1
+                continue
+
+            # loading
+            if self.load_results[idx][0]:
+                filename = self.load_results[idx][1]
+                ret = self.load_step(action_idx=idx, filename=filename)
+                self._intermediate_result = ret
+                logger.info(
+                    '[No Calculation] Loaded result for action >>%s<< successfully',
+                    self.action_names[idx],
+                )
+                self.curr_proc_idx += 1
+                continue
+            # calculation
+            if not first_performed:
+                args = starting_values
+                first_performed = True
+            else:
+                args = ret
+
+            if args is not None:
+                ret = action(*args, **action_kwargs)
+            else:
+                ret = action(**action_kwargs)
+
+            if ret is not None and not isinstance(ret, tuple):
+                ret = (ret,)
+            ret = cast(tuple[Any, ...], ret)
+            # save intermediate result
+            self._intermediate_result = ret
+            # saving result locally, always save last action
+            if self.save_results[idx][0] or idx == (len(self.actions) - 1):
+                filename = self.save_results[idx][1]
+                self.save_step(action_idx=idx, filename=filename)
+            # processing tracking
+            self.curr_proc_idx += 1
+
+        return ret
+
+

Helper class that provides a standard way to create an ABC using +inheritance.

+

Ancestors

+ +

Methods

+
+
+def add(self,
action: Callable,
action_kwargs: dict[str, Any] | None = None,
skip: bool = False,
save_result: bool = False,
load_result: bool = False,
filename: str | None = None) ‑> None
+
+
+
+ +Expand source code + +
@override
+def add(
+    self,
+    action: Callable,
+    action_kwargs: dict[str, Any] | None = None,
+    skip: bool = False,
+    save_result: bool = False,
+    load_result: bool = False,
+    filename: str | None = None,
+) -> None:
+    # check explicitly for function type
+    # if isinstance(action, FunctionType):
+    if action_kwargs is None:
+        action_kwargs = {}
+
+    if isinstance(action, Callable):
+        self.actions.append(action)
+        self.action_names.append(action.__name__)
+        self.actions_kwargs.append(action_kwargs.copy())
+        self.action_skip.append(skip)
+        self.save_results.append((save_result, filename))
+        self.load_results.append((load_result, filename))
+    else:
+        self.panic_wrong_action_type(action=action, compatible_type=Callable.__name__)
+
+
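+A minimal sketch of the save/load options (the import path, pipeline name, output
+directory and expensive_step are placeholder assumptions; with filename=None the
+result of the first action would be stored as Pipe-Demo_Step-0_expensive_step.pkl):
+
+from pathlib import Path
+
+from lang_main.pipelines.base import Pipeline  # import path assumed
+
+def expensive_step() -> dict:
+    # placeholder for a long-running computation
+    return {'answer': 42}
+
+pipe = Pipeline(name='Demo', working_dir=Path('out'))
+# first run: compute the step and cache its result as out/expensive.pkl
+pipe.add(expensive_step, save_result=True, filename='expensive')
+# a later run can pass load_result=True with the same filename to reuse the
+# pickled result instead of recomputing the step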
+
+
+def get_result_path(self, action_idx: int, filename: str | None) ‑> tuple[pathlib.Path, str] +
+
+
+ +Expand source code + +
def get_result_path(
+    self,
+    action_idx: int,
+    filename: str | None,
+) -> tuple[Path, str]:
+    action_name = self.action_names[action_idx]
+    if filename is None:
+        target_filename = f'Pipe-{self.name}_Step-{self.curr_proc_idx}_{action_name}'
+    else:
+        target_filename = filename
+    target_path = self.working_dir.joinpath(target_filename).with_suffix('.pkl')
+    return target_path, action_name
+
+
+
+
+def load_step(self, action_idx: int, filename: str | None) ‑> tuple[typing.Any, ...] +
+
+
+ +Expand source code + +
def load_step(
+    self,
+    action_idx: int,
+    filename: str | None,
+) -> tuple[Any, ...]:
+    target_path, action_name = self.get_result_path(action_idx, filename)
+
+    if not target_path.exists():
+        raise FileNotFoundError(
+            (
+                f'No intermediate results for action >>{action_name}<< '
+                f'under >>{target_path}<< found'
+            )
+        )
+    # results should be tuple, but that is not guaranteed
+    result_loaded = cast(tuple[Any, ...], load_pickle(target_path))
+    if not isinstance(result_loaded, tuple):
+        raise TypeError(f'Loaded results must be tuple, not {type(result_loaded)}')
+
+    return result_loaded
+
+
+
+
+def logic(self, starting_values: tuple[Any, ...] | None = None) ‑> tuple[typing.Any, ...] +
+
+
+ +Expand source code + +
@override
+def logic(
+    self,
+    starting_values: tuple[Any, ...] | None = None,
+) -> tuple[Any, ...]:
+    first_performed: bool = False
+
+    for idx, (action, action_kwargs) in enumerate(zip(self.actions, self.actions_kwargs)):
+        if self.action_skip[idx]:
+            self.curr_proc_idx += 1
+            continue
+
+        # loading
+        if self.load_results[idx][0]:
+            filename = self.load_results[idx][1]
+            ret = self.load_step(action_idx=idx, filename=filename)
+            self._intermediate_result = ret
+            logger.info(
+                '[No Calculation] Loaded result for action >>%s<< successfully',
+                self.action_names[idx],
+            )
+            self.curr_proc_idx += 1
+            continue
+        # calculation
+        if not first_performed:
+            args = starting_values
+            first_performed = True
+        else:
+            args = ret
+
+        if args is not None:
+            ret = action(*args, **action_kwargs)
+        else:
+            ret = action(**action_kwargs)
+
+        if ret is not None and not isinstance(ret, tuple):
+            ret = (ret,)
+        ret = cast(tuple[Any, ...], ret)
+        # save intermediate result
+        self._intermediate_result = ret
+        # saving result locally, always save last action
+        if self.save_results[idx][0] or idx == (len(self.actions) - 1):
+            filename = self.save_results[idx][1]
+            self.save_step(action_idx=idx, filename=filename)
+        # processing tracking
+        self.curr_proc_idx += 1
+
+    return ret
+
+
+
+
+def save_step(self, action_idx: int, filename: str | None) ‑> None +
+
+
+ +Expand source code + +
def save_step(
+    self,
+    action_idx: int,
+    filename: str | None,
+) -> None:
+    target_path, _ = self.get_result_path(action_idx, filename)
+    save_pickle(obj=self._intermediate_result, path=target_path)
+
+
+
+
+
+
+class PipelineContainer +(name: str, working_dir: Path) +
+
+
+ +Expand source code + +
class PipelineContainer(BasePipeline):
+    """Container class for basic actions.
+    Basic actions are usually functions, which do not take any parameters
+    and return nothing. Indeed, if an action returns any values after its
+    procedure is finished, an error is raised. Therefore, PipelineContainers
+    can be seen as a concatenation of many (independent) simple procedures
+    which are executed in the order in which they were added to the pipe.
+    With a simple call of the ``run`` method the actions are performed.
+    Additionally, there is an option to skip actions which can be set in
+    the ``add`` method. This allows for easily configurable pipelines,
+    e.g., via a user configuration.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        working_dir: Path,
+    ) -> None:
+        super().__init__(name=name, working_dir=working_dir)
+
+    @override
+    def add(
+        self,
+        action: Callable,
+        skip: bool = False,
+    ) -> None:
+        if isinstance(action, Callable):
+            self.actions.append(action)
+            self.action_names.append(action.__name__)
+            self.action_skip.append(skip)
+        else:
+            self.panic_wrong_action_type(action=action, compatible_type=Callable.__name__)
+
+    @override
+    def logic(self) -> None:
+        for idx, (action, action_name) in enumerate(zip(self.actions, self.action_names)):
+            # loading
+            if self.action_skip[idx]:
+                logger.info('[No Calculation] Skipping >>%s<<...', action_name)
+                self.curr_proc_idx += 1
+                continue
+            # calculation
+            ret = action()
+            if ret is not None:
+                raise OutputInPipelineContainerError(
+                    (
+                        f'Output in PipelineContainers not allowed. Action {action_name} '
+                        f'returned values in Container {self.name}.'
+                    )
+                )
+            # processing tracking
+            self.curr_proc_idx += 1
+
+

Container class for basic actions. +Basic actions are usually functions, which do not take any parameters +and return nothing. Indeed, if an action returns any values after its +procedure is finished, an error is raised. Therefore, PipelineContainers +can be seen as a concatenation of many (independent) simple procedures +which are executed in the order in which they were added to the pipe. +With a simple call of the run method the actions are performed. +Additionally, there is an option to skip actions which can be set in +the add method. This allows for easily configurable pipelines, +e.g., via a user configuration.
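+A minimal usage sketch, assuming prepare_folders and cleanup_temp_files are
+parameterless placeholder functions that return nothing and that the container
+is imported from lang_main.pipelines.base:
+
+from pathlib import Path
+
+from lang_main.pipelines.base import PipelineContainer  # import path assumed
+
+def prepare_folders() -> None:
+    Path('out').mkdir(exist_ok=True)
+
+def cleanup_temp_files() -> None:
+    ...
+
+container = PipelineContainer(name='Housekeeping', working_dir=Path('out'))
+container.add(prepare_folders)
+container.add(cleanup_temp_files, skip=True)  # deactivated, e.g. via user config
+container.run()  # executes the remaining actions in insertion order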

+

Ancestors

+ +

Methods

+
+
+def add(self, action: Callable, skip: bool = False) ‑> None +
+
+
+ +Expand source code + +
@override
+def add(
+    self,
+    action: Callable,
+    skip: bool = False,
+) -> None:
+    if isinstance(action, Callable):
+        self.actions.append(action)
+        self.action_names.append(action.__name__)
+        self.action_skip.append(skip)
+    else:
+        self.panic_wrong_action_type(action=action, compatible_type=Callable.__name__)
+
+
+
+
+def logic(self) ‑> None +
+
+
+ +Expand source code + +
@override
+def logic(self) -> None:
+    for idx, (action, action_name) in enumerate(zip(self.actions, self.action_names)):
+        # loading
+        if self.action_skip[idx]:
+            logger.info('[No Calculation] Skipping >>%s<<...', action_name)
+            self.curr_proc_idx += 1
+            continue
+        # calculation
+        ret = action()
+        if ret is not None:
+            raise OutputInPipelineContainerError(
+                (
+                    f'Output in PipelineContainers not allowed. Action {action_name} '
+                    f'returned values in Container {self.name}.'
+                )
+            )
+        # processing tracking
+        self.curr_proc_idx += 1
+
+
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/pipelines/index.html b/docs/lang_main/pipelines/index.html new file mode 100644 index 0000000..810682b --- /dev/null +++ b/docs/lang_main/pipelines/index.html @@ -0,0 +1,83 @@ + + + + + + +lang_main.pipelines API documentation + + + + + + + + + + + +
+ + +
+ + + diff --git a/docs/lang_main/pipelines/predefined.html b/docs/lang_main/pipelines/predefined.html new file mode 100644 index 0000000..e66d89a --- /dev/null +++ b/docs/lang_main/pipelines/predefined.html @@ -0,0 +1,386 @@ + + + + + + +lang_main.pipelines.predefined API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.pipelines.predefined

+
+
+
+
+
+
+
+
+

Functions

+
+
+def build_base_target_feature_pipe() ‑> Pipeline +
+
+
+ +Expand source code + +
def build_base_target_feature_pipe() -> Pipeline:
+    pipe_target_feat = Pipeline(name='Target_Feature', working_dir=SAVE_PATH_FOLDER)
+    pipe_target_feat.add(
+        load_raw_data,
+        {
+            'date_cols': DATE_COLS,
+        },
+    )
+    pipe_target_feat.add(remove_duplicates)
+    pipe_target_feat.add(remove_NA, save_result=True)
+    pipe_target_feat.add(
+        entry_wise_cleansing,
+        {
+            'target_features': (TARGET_FEATURE,),
+            'cleansing_func': clean_string_slim,
+        },
+        save_result=True,
+        filename=EntryPoints.TIMELINE,
+    )
+    pipe_target_feat.add(
+        analyse_feature,
+        {
+            'target_feature': TARGET_FEATURE,
+        },
+        save_result=True,
+    )
+
+    return pipe_target_feat
+
+
+
+
+def build_merge_duplicates_pipe() ‑> Pipeline +
+
+
+ +Expand source code + +
def build_merge_duplicates_pipe() -> Pipeline:
+    pipe_merge = Pipeline(name='Merge_Duplicates', working_dir=SAVE_PATH_FOLDER)
+    pipe_merge.add(
+        numeric_pre_filter_feature,
+        {
+            'feature': 'len',
+            'bound_lower': THRESHOLD_AMOUNT_CHARACTERS,
+            'bound_upper': None,
+        },
+    )
+    pipe_merge.add(
+        merge_similarity_duplicates,
+        {
+            'model': STFR_MODEL,
+            'cos_sim_threshold': THRESHOLD_SIMILARITY,
+        },
+        save_result=True,
+        filename=EntryPoints.TOKEN_ANALYSIS,
+    )
+
+    return pipe_merge
+
+
+
+
+def build_timeline_pipe() ‑> Pipeline +
+
+
+ +Expand source code + +
def build_timeline_pipe() -> Pipeline:
+    pipe_timeline = Pipeline(name='Timeline_Analysis', working_dir=SAVE_PATH_FOLDER)
+    pipe_timeline.add(
+        cleanup_descriptions,
+        {
+            'properties': ['ErledigungsBeschreibung'],
+        },
+    )
+    pipe_timeline.add(
+        calc_delta_to_repair,
+        {
+            'date_feature_start': 'ErstellungsDatum',
+            'date_feature_end': 'ErledigungsDatum',
+            'name_delta_feature': NAME_DELTA_FEAT_TO_REPAIR,
+            'convert_to_days': True,
+        },
+        save_result=True,
+        filename=EntryPoints.TIMELINE_POST,
+    )
+    pipe_timeline.add(
+        remove_non_relevant_obj_ids,
+        {
+            'thresh_unique_feat_per_id': THRESHOLD_UNIQUE_TEXTS,
+            'feature_uniqueness': UNIQUE_CRITERION_FEATURE,
+            'feature_obj_id': FEATURE_NAME_OBJ_ID,
+        },
+        save_result=True,
+    )
+    pipe_timeline.add(
+        generate_model_input,
+        {
+            'target_feature_name': 'nlp_model_input',
+            'model_input_features': MODEL_INPUT_FEATURES,
+        },
+    )
+    pipe_timeline.add(
+        filter_activities_per_obj_id,
+        {
+            'activity_feature': ACTIVITY_FEATURE,
+            'relevant_activity_types': ACTIVITY_TYPES,
+            'feature_obj_id': FEATURE_NAME_OBJ_ID,
+            'threshold_num_activities': THRESHOLD_NUM_ACTIVITIES,
+        },
+    )
+    pipe_timeline.add(
+        get_timeline_candidates,
+        {
+            'model': STFR_MODEL,
+            'cos_sim_threshold': THRESHOLD_TIMELINE_SIMILARITY,
+            'feature_obj_id': FEATURE_NAME_OBJ_ID,
+            'feature_obj_text': FEATURE_NAME_OBJ_TEXT,
+            'model_input_feature': 'nlp_model_input',
+        },
+        save_result=True,
+        filename=EntryPoints.TIMELINE_CANDS,
+    )
+
+    return pipe_timeline
+
+
+
+
+def build_tk_graph_pipe() ‑> Pipeline +
+
+
+ +Expand source code + +
def build_tk_graph_pipe() -> Pipeline:
+    pipe_token_analysis = Pipeline(name='Token_Analysis', working_dir=SAVE_PATH_FOLDER)
+    pipe_token_analysis.add(
+        build_token_graph,
+        {
+            'model': SPACY_MODEL,
+            'target_feature': 'entry',
+            'weights_feature': 'num_occur',
+            'batch_idx_feature': 'batched_idxs',
+            'build_map': False,
+            'batch_size_model': 50,
+        },
+        save_result=True,
+        filename=EntryPoints.TK_GRAPH_POST,
+    )
+
+    return pipe_token_analysis
+
+
+
+
+def build_tk_graph_post_pipe() ‑> Pipeline +
+
+
+ +Expand source code + +
def build_tk_graph_post_pipe() -> Pipeline:
+    pipe_graph_postprocessing = Pipeline(
+        name='Graph_Postprocessing', working_dir=SAVE_PATH_FOLDER
+    )
+    pipe_graph_postprocessing.add(
+        graphs.filter_graph_by_number_edges,
+        {
+            'limit': MAX_EDGE_NUMBER,
+            'property': 'weight',
+        },
+    )
+    pipe_graph_postprocessing.add(
+        graphs.filter_graph_by_node_degree,
+        {
+            'bound_lower': 1,
+            'bound_upper': None,
+        },
+    )
+    pipe_graph_postprocessing.add(
+        graphs.static_graph_analysis,
+        save_result=True,
+        filename=EntryPoints.TK_GRAPH_ANALYSIS,
+    )
+
+    return pipe_graph_postprocessing
+
+
+
+
+def build_tk_graph_render_pipe(with_subgraphs: bool,
export_folder: pathlib.Path = WindowsPath('A:/Arbeitsaufgaben/lang-data/out'),
base_network_name: str = 'token_graph') ‑> Pipeline
+
+
+
+ +Expand source code + +
def build_tk_graph_render_pipe(
+    with_subgraphs: bool,
+    export_folder: Path = SAVE_PATH_FOLDER,
+    base_network_name: str = CYTO_BASE_NETWORK_NAME,
+) -> Pipeline:
+    # optional dependency: late import
+    # raises exception if necessary modules are not found
+    try:
+        from lang_main.render import cytoscape as cyto
+    except ImportError:
+        raise ImportError(
+            (
+                'Dependencies for Cytoscape interaction not found. '
+                'Install the package with optional dependencies.'
+            )
+        )
+
+    pipe_graph_rendering = Pipeline(
+        name='Graph_Static-Rendering',
+        working_dir=SAVE_PATH_FOLDER,
+    )
+    pipe_graph_rendering.add(
+        cyto.import_to_cytoscape,
+        {
+            'network_name': base_network_name,
+        },
+    )
+    pipe_graph_rendering.add(
+        cyto.layout_network,
+        {
+            'network_name': base_network_name,
+        },
+    )
+    pipe_graph_rendering.add(
+        cyto.apply_style_to_network,
+        {
+            'network_name': base_network_name,
+        },
+    )
+    pipe_graph_rendering.add(
+        cyto.export_network_to_image,
+        {
+            'filename': base_network_name,
+            'target_folder': export_folder,
+            'network_name': base_network_name,
+        },
+    )
+
+    if with_subgraphs:
+        pipe_graph_rendering.add(
+            cyto.get_subgraph_node_selection,
+            {
+                'network_name': base_network_name,
+            },
+        )
+        pipe_graph_rendering.add(
+            cyto.build_subnetworks,
+            {
+                'export_image': True,
+                'target_folder': export_folder,
+                'network_name': base_network_name,
+            },
+        )
+
+    return pipe_graph_rendering
+
+
+
+
+def build_tk_graph_rescaling_pipe(save_result: bool, exit_point: lang_main.types.EntryPoints) ‑> Pipeline +
+
+
+ +Expand source code + +
def build_tk_graph_rescaling_pipe(
+    save_result: bool,
+    exit_point: EntryPoints,
+) -> Pipeline:
+    pipe_graph_rescaling = Pipeline(name='Graph_Rescaling', working_dir=SAVE_PATH_FOLDER)
+    pipe_graph_rescaling.add(
+        graphs.pipe_rescale_graph_edge_weights,
+    )
+    pipe_graph_rescaling.add(
+        graphs.pipe_add_graph_metrics,
+        save_result=save_result,
+        filename=exit_point,
+    )
+
+    return pipe_graph_rescaling
+
+
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/render/cytoscape.html b/docs/lang_main/render/cytoscape.html new file mode 100644 index 0000000..f9c643a --- /dev/null +++ b/docs/lang_main/render/cytoscape.html @@ -0,0 +1,797 @@ + + + + + + +lang_main.render.cytoscape API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.render.cytoscape

+
+
+
+
+
+
+
+
+

Functions

+
+
+def analyse_network(property_degree_weighted: str = 'degree_weighted',
network_name: str = 'token_graph') ‑> None
+
+
+
+ +Expand source code + +
def analyse_network(
+    property_degree_weighted: str = PROPERTY_NAME_DEGREE_WEIGHTED,
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+) -> None:
+    node_table = p4c.get_table_columns(table='node', network=network_name)
+    net_analyse_possible: bool = True
+    if len(node_table) < 4:  # pragma: no cover
+        net_analyse_possible = False
+
+    if net_analyse_possible:
+        p4c.analyze_network(directed=False)
+        node_table = p4c.get_table_columns(table='node', network=network_name)
+        node_table['stress_norm'] = node_table['Stress'] / node_table['Stress'].max()
+        node_table[CYTO_SELECTION_PROPERTY] = (
+            node_table[property_degree_weighted]
+            * node_table['BetweennessCentrality']
+            * node_table['stress_norm']
+        )
+    else:  # pragma: no cover
+        node_table[CYTO_SELECTION_PROPERTY] = 1
+
+    p4c.load_table_data(node_table, data_key_column='name', network=network_name)
+
+
+
+
+def apply_style_to_network(style_name: str = 'lang_main',
pth_to_stylesheet: pathlib.Path = WindowsPath('A:/Arbeitsaufgaben/lang-main/src/lang_main/cytoscape_config/lang_main.xml'),
network_name: str = 'token_graph',
node_size_property: str = 'node_selection',
min_node_size: int = 15,
max_node_size: int = 40,
sandbox_name: str = 'lang_main') ‑> None
+
+
+
+ +Expand source code + +
def apply_style_to_network(
+    style_name: str = CYTO_STYLESHEET_NAME,
+    pth_to_stylesheet: Path = CYTO_PATH_STYLESHEET,
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+    node_size_property: str = CYTO_SELECTION_PROPERTY,
+    min_node_size: int = 15,
+    max_node_size: int = 40,
+    sandbox_name: str = CYTO_SANDBOX_NAME,
+) -> None:
+    """Cytoscape: apply a chosen Cytoscape style to the defined network
+
+    Parameters
+    ----------
+    style_name : str, optional
+        Cytoscape name of the style which should be applied,
+        by default CYTO_STYLESHEET_NAME
+    pth_to_stylesheet : Path, optional
+        path where the stylesheet definition in Cytoscape's XML format can
+        be found,
+        by default CYTO_PATH_STYLESHEET
+    network_name : str, optional
+        network to apply the style on, by default CYTO_BASE_NETWORK_NAME
+
+    Raises
+    ------
+    FileNotFoundError
+        if provided stylesheet can not be found under the provided path
+    """
+    logger.debug('Applying style to network...')
+    styles_avail = cast(list[str], p4c.get_visual_style_names())
+    logger.debug('Available styles: %s', styles_avail)
+    if style_name not in styles_avail:
+        if not pth_to_stylesheet.exists():
+            # existence for standard path verified at import, but not for other
+            # provided paths
+            raise FileNotFoundError(
+                f'Visual stylesheet for Cytoscape not found under: >>{pth_to_stylesheet}<<'
+            )
+        # send to sandbox
+        sandbox_filename = pth_to_stylesheet.name
+        p4c.sandbox_send_to(
+            source_file=pth_to_stylesheet,
+            dest_file=sandbox_filename,
+            overwrite=True,
+            sandbox_name=sandbox_name,
+        )
+        # load stylesheet
+        p4c.import_visual_styles(sandbox_filename)
+
+    p4c.set_visual_style(style_name, network=network_name)
+    # node size mapping, only if needed property is available
+    scheme = p4c.scheme_c_number_continuous(
+        start_value=min_node_size, end_value=max_node_size
+    )
+    node_size_map = p4c.gen_node_size_map(
+        node_size_property,
+        number_scheme=scheme,
+        mapping_type='c',
+        style_name=style_name,
+        default_number=min_node_size,
+    )
+    p4c.set_node_size_mapping(**node_size_map)
+    fit_content(network_name=network_name)
+    logger.debug('Style application to network successful.')
+
+

Cytoscape: apply a chosen Cytoscape style to the defined network

+

Parameters

+
+
style_name : str, optional
+
Cytoscape name of the style which should be applied, +by default CYTO_STYLESHEET_NAME
+
pth_to_stylesheet : Path, optional
+
path where the stylesheet definition in Cytoscape's XML format can +be found, +by default CYTO_PATH_STYLESHEET
+
network_name : str, optional
+
network to apply the style on, by default CYTO_BASE_NETWORK_NAME
+
+

Raises

+
+
FileNotFoundError
+
if provided stylesheet can not be found under the provided path
+
+
+
+def build_subnetworks(nodes_to_analyse: Iterable[int],
network_name: str = 'token_graph',
export_image: bool = True,
target_folder: pathlib.Path = WindowsPath('A:/Arbeitsaufgaben/lang-data/out')) ‑> None
+
+
+
+ +Expand source code + +
def build_subnetworks(
+    nodes_to_analyse: Iterable[CytoNodeID],
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+    export_image: bool = True,
+    target_folder: Path = SAVE_PATH_FOLDER,
+) -> None:
+    """Cytoscape: iteratively build subnetworks from a collection of nodes
+    and their respective neighbouring nodes
+
+    Parameters
+    ----------
+    nodes_to_analyse : Iterable[CytoNodeID]
+        collection of nodes to make subnetworks from, for each node a dedicated
+        subnetwork will be generated
+    network_name : str, optional
+        network which contains the provided nodes,
+        by default CYTO_BASE_NETWORK_NAME
+    export_image : bool, optional
+        trigger image export of newly generated subnetworks, by default True
+    """
+    logger.debug('Generating all subnetworks for node selection...')
+    for idx, node in enumerate(nodes_to_analyse):
+        select_neighbours_of_node(node=node, network_name=network_name)
+        make_subnetwork(
+            index=idx,
+            network_name=network_name,
+            export_image=export_image,
+            target_folder=target_folder,
+        )
+    logger.debug('Generation of all subnetworks for node selection successful.')
+
+

Cytoscape: iteratively build subnetworks from a collection of nodes +and their respective neighbouring nodes

+

Parameters

+
+
nodes_to_analyse : Iterable[CytoNodeID]
+
collection of nodes to make subnetworks from, for each node a dedicated +subnetwork will be generated
+
network_name : str, optional
+
network which contains the provided nodes, +by default CYTO_BASE_NETWORK_NAME
+
export_image : bool, optional
+
trigger image export of newly generated subnetworks, by default True
+
+
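+A short sketch pairing this function with get_subgraph_node_selection, mirroring
+the with_subgraphs branch of build_tk_graph_render_pipe (network name and count
+are the documented defaults):
+
+from lang_main.render import cytoscape as cyto
+
+nodes = cyto.get_subgraph_node_selection(network_name='token_graph', num_subgraphs=5)
+cyto.build_subnetworks(nodes, network_name='token_graph', export_image=True)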
+
+def change_default_layout() ‑> None +
+
+
+ +Expand source code + +
def change_default_layout() -> None:
+    """Cytoscape: resets the default layout to `grid` to accelerate the import process
+    (the grid layout is one of the fastest)
+
+    Raises
+    ------
+    RequestException
+        API endpoint not reachable or CyREST operation not successful
+    """
+    body: dict[str, str] = {'value': 'grid', 'key': 'layout.default'}
+    try:
+        p4c.cyrest_put('properties/cytoscape3.props/layout.default', body=body)
+    except RequestException as error:
+        logger.error('[CytoAPIConnection] Property change of default layout not successful.')
+        raise error
+
+

Cytoscape: resets the default layout to grid to accelerate the import process +(the grid layout is one of the fastest)

+

Raises

+
+
RequestException
+
API endpoint not reachable or CyREST operation not successful
+
+
+
+def export_network_to_image(filename: str,
target_folder: pathlib.Path = WindowsPath('A:/Arbeitsaufgaben/lang-data/out'),
filetype: Literal['JPEG', 'PDF', 'PNG', 'PS', 'SVG'] = 'SVG',
network_name: str = 'token_graph',
pdf_export_page_size: Literal['A0', 'A1', 'A2', 'A3', 'A4', 'A5', 'Auto', 'Legal', 'Letter', 'Tabloid'] = 'A4',
sandbox_name: str = 'lang_main') ‑> None
+
+
+
+ +Expand source code + +
def export_network_to_image(
+    filename: str,
+    target_folder: Path = SAVE_PATH_FOLDER,
+    filetype: CytoExportFileTypes = 'SVG',
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+    pdf_export_page_size: CytoExportPageSizes = 'A4',
+    sandbox_name: str = CYTO_SANDBOX_NAME,
+) -> None:
+    """Cytoscape: export current selected view as image
+
+    Parameters
+    ----------
+    filename : str
+        export filename
+    filetype : CytoExportFileTypes, optional
+        export filetype supported by Cytoscape, by default 'SVG'
+    network_name : str, optional
+        network to export, by default CYTO_BASE_NETWORK_NAME
+    pdf_export_page_size : CytoExportPageSizes, optional
+        page size which should be used for PDF exports supported by Cytoscape,
+        by default 'A4'
+    """
+    logger.debug('Exporting image to file...')
+    if not target_folder.exists():  # pragma: no cover
+        target_folder.mkdir(parents=True)
+    dst_file_pth = (target_folder / filename).with_suffix(f'.{filetype.lower()}')
+
+    text_as_font = True
+    if filetype == 'SVG':
+        text_as_font = False
+
+    # close non-necessary windows and fit graph in frame before image display
+    fit_content(network_name=network_name)
+    # image is generated in sandbox directory and transferred to target destination
+    # (preparation for remote instances of Cytoscape)
+    p4c.export_image(
+        filename=filename,
+        type=filetype,
+        network=network_name,
+        overwrite_file=True,
+        all_graphics_details=True,
+        export_text_as_font=text_as_font,
+        page_size=pdf_export_page_size,
+    )
+    logger.debug('Exported image to sandbox.')
+    logger.debug('Transferring image from sandbox to target destination...')
+    sandbox_filename = f'{filename}.{filetype.lower()}'
+    p4c.sandbox_get_from(
+        source_file=sandbox_filename,
+        dest_file=str(dst_file_pth),
+        overwrite=True,
+        sandbox_name=sandbox_name,
+    )
+    logger.debug('Transfer of image from sandbox to target destination successful.')
+
+

Cytoscape: export current selected view as image

+

Parameters

+
+
filename : str
+
export filename
+
filetype : CytoExportFileTypes, optional
+
export filetype supported by Cytoscape, by default 'SVG'
+
network_name : str, optional
+
network to export, by default CYTO_BASE_NETWORK_NAME
+
pdf_export_page_size : CytoExportPageSizes, optional
+
page size which should be used for PDF exports supported by Cytoscape, +by default 'A4'
+
+
+
+def fit_content(zoom_factor: float = 0.96, network_name: str = 'token_graph') ‑> None +
+
+
+ +Expand source code + +
def fit_content(
+    zoom_factor: float = CYTO_NETWORK_ZOOM_FACTOR,
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+) -> None:
+    p4c.hide_all_panels()
+    p4c.fit_content(selected_only=False, network=network_name)
+    zoom_current = p4c.get_network_zoom(network=network_name)
+    zoom_new = zoom_current * zoom_factor
+    p4c.set_network_zoom_bypass(zoom_new, bypass=False, network=network_name)
+
+
+
+
+def get_subgraph_node_selection(network_name: str = 'token_graph', num_subgraphs: int = 5) ‑> list[int] +
+
+
+ +Expand source code + +
def get_subgraph_node_selection(
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+    num_subgraphs: int = CYTO_NUMBER_SUBGRAPHS,
+) -> list[CytoNodeID]:
+    """Cytoscape: obtain the relevant nodes for iterative subgraph generation
+
+    Parameters
+    ----------
+    network_name : str, optional
+        network to retrieve the nodes from, by default CYTO_BASE_NETWORK_NAME
+    num_subgraphs : int, optional
+        number of relevant nodes which form the basis to generate subgraphs from,
+        by default CYTO_NUMBER_SUBGRAPHS
+
+    Returns
+    -------
+    list[CytoNodeID]
+        list containing all relevant Cytoscape nodes
+    """
+    logger.debug('Selecting nodes for subgraph generation...')
+    node_table = p4c.get_table_columns(table='node', network=network_name)
+    node_table = node_table.sort_values(by=CYTO_SELECTION_PROPERTY, ascending=False)
+    p4c.load_table_data(node_table, data_key_column='name', network=network_name)
+    node_table_choice = node_table.iloc[:num_subgraphs]
+    logger.debug('Selection of nodes for subgraph generation successful.')
+
+    return node_table_choice['SUID'].to_list()
+
+

Cytoscape: obtain the relevant nodes for iterative subgraph generation

+

Parameters

+
+
network_name : str, optional
+
network to retrieve the nodes from, by default CYTO_BASE_NETWORK_NAME
+
num_subgraphs : int, optional
+
number of relevant nodes which form the basis to generate subgraphs from, +by default CYTO_NUMBER_SUBGRAPHS
+
+

Returns

+
+
list[CytoNodeID]
+
list containing all relevant Cytoscape nodes
+
+
+
+def import_to_cytoscape(graph: networkx.classes.digraph.DiGraph | networkx.classes.graph.Graph,
network_name: str = 'token_graph',
sandbox_name: str = 'lang_main',
reinitialise_sandbox: bool = True) ‑> None
+
+
+
+ +Expand source code + +
def import_to_cytoscape(
+    graph: DiGraph | Graph,
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+    sandbox_name: str = CYTO_SANDBOX_NAME,
+    reinitialise_sandbox: bool = True,
+) -> None:
+    """Cytoscape: import NetworkX graph as new network collection
+
+    Parameters
+    ----------
+    graph : DiGraph | Graph
+        NetworkX graph object
+    """
+    logger.debug('Checking Cytoscape connection...')
+    verify_connection()
+    logger.debug('Checking graph size for rendering...')
+    verify_graph_render_size(graph)
+    logger.debug('Setting default layout to improve import speed...')
+    change_default_layout()
+    logger.debug('Setting Cytoscape sandbox...')
+    p4c.sandbox_set(
+        sandbox_name=sandbox_name,
+        reinitialize=reinitialise_sandbox,
+        copy_samples=False,
+    )
+    logger.debug('Importing to and analysing network in Cytoscape...')
+    p4c.delete_all_networks()
+    p4c.create_network_from_networkx(
+        graph,
+        title=network_name,
+        collection=CYTO_COLLECTION_NAME,
+    )
+    analyse_network(network_name=network_name)
+    logger.debug('Import and analysis of network to Cytoscape successful.')
+
+

Cytoscape: import NetworkX graph as new network collection

+

Parameters

+
+
graph : DiGraph | Graph
+
NetworkX graph object
+
+
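+A minimal end-to-end sketch that mirrors the call sequence used by
+build_tk_graph_render_pipe (the example graph is a placeholder small enough to
+render; a running Cytoscape instance with CyREST is required):
+
+import networkx as nx
+
+from lang_main.render import cytoscape as cyto
+
+graph = nx.karate_club_graph()  # placeholder graph, 34 nodes / 78 edges
+cyto.import_to_cytoscape(graph, network_name='token_graph')
+cyto.layout_network(network_name='token_graph')
+cyto.apply_style_to_network(network_name='token_graph')
+cyto.export_network_to_image(filename='token_graph', network_name='token_graph')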
+
+def layout_network(layout_name: Literal['attribute-circle', 'attribute-grid', 'attributes-layout', 'circular', 'cose', 'degree-circle', 'force-directed', 'force-directed-cl', 'fruchterman-rheingold', 'grid', 'hierarchical', 'isom', 'kamada-kawai', 'stacked-node-layout'] = 'force-directed',
layout_properties: dict[str, float | bool] = {'numIterations': 1000, 'defaultSpringCoefficient': 0.0001, 'defaultSpringLength': 45, 'defaultNodeMass': 11, 'isDeterministic': True, 'singlePartition': False},
network_name: str = 'token_graph') ‑> None
+
+
+
+ +Expand source code + +
def layout_network(
+    layout_name: CytoLayouts = CYTO_LAYOUT_NAME,
+    layout_properties: CytoLayoutProperties = CYTO_LAYOUT_PROPERTIES,
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+) -> None:
+    """Cytoscape: apply a supported layout algorithm to currently selected
+    network
+
+    Parameters
+    ----------
+    layout_name : CytoLayouts, optional
+        layout algorithm supported by Cytoscape (name of the CyREST API, does not
+        necessarily match the name in the Cytoscape UI),
+        by default CYTO_LAYOUT_NAME
+    layout_properties : CytoLayoutProperties, optional
+        configuration of parameters for the given layout algorithm,
+        by default CYTO_LAYOUT_PROPERTIES
+    network_name : str, optional
+        network to apply the layout algorithm on, by default CYTO_BASE_NETWORK_NAME
+    """
+    logger.debug('Applying layout to network...')
+    p4c.set_layout_properties(layout_name, layout_properties)
+    p4c.layout_network(layout_name=layout_name, network=network_name)
+    fit_content(network_name=network_name)
+    logger.debug('Layout application to network successful.')
+
+

Cytoscape: apply a supported layout algorithm to currently selected +network

+

Parameters

+
+
layout_name : CytoLayouts, optional
+
layout algorithm supported by Cytoscape (name of the CyREST API, does not +necessarily match the name in the Cytoscape UI), +by default CYTO_LAYOUT_NAME
+
layout_properties : CytoLayoutProperties, optional
+
configuration of parameters for the given layout algorithm, +by default CYTO_LAYOUT_PROPERTIES
+
network_name : str, optional
+
network to apply the layout algorithm on, by default CYTO_BASE_NETWORK_NAME
+
+
+
+def make_subnetwork(index: int,
network_name: str = 'token_graph',
export_image: bool = True,
target_folder: pathlib.Path = WindowsPath('A:/Arbeitsaufgaben/lang-data/out')) ‑> None
+
+
+
+ +Expand source code + +
def make_subnetwork(
+    index: int,
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+    export_image: bool = True,
+    target_folder: Path = SAVE_PATH_FOLDER,
+) -> None:
+    """Cytoscape: generate a new subnetwork based on the currently
+    selected nodes and edges
+
+    Parameters
+    ----------
+    index : int
+        id-like property to identify the subnetwork relative to its parent
+    network_name : str, optional
+        network to generate subnetwork from, by default CYTO_BASE_NETWORK_NAME
+    export_image : bool, optional
+        trigger image export of newly generated subnetwork, by default True
+    """
+    logger.debug('Generating subnetwork with index %d...', index)
+    subnetwork_name = network_name + f'_sub_{index+1}'
+    p4c.create_subnetwork(
+        nodes='selected',
+        edges='selected',
+        subnetwork_name=subnetwork_name,
+        network=network_name,
+    )
+    p4c.set_current_network(subnetwork_name)
+
+    if export_image:
+        time.sleep(1)
+        export_network_to_image(
+            filename=subnetwork_name,
+            target_folder=target_folder,
+            network_name=subnetwork_name,
+        )
+
+    logger.debug('Generation of subnetwork with index %d successful.', index)
+
+

Cytoscape: generate a new subnetwork based on the currently +selected nodes and edges

+

Parameters

+
+
index : int
+
id-like property to identify the subnetwork relative to its parent
+
network_name : str, optional
+
network to generate subnetwork from, by default CYTO_BASE_NETWORK_NAME
+
export_image : bool, optional
+
trigger image export of newly generated subnetwork, by default True
+
+
+
+def reset_current_network_to_base() ‑> None +
+
+
+ +Expand source code + +
def reset_current_network_to_base() -> None:
+    """resets to currently selected network in Cytoscape back to the base one"""
+    p4c.set_current_network(CYTO_BASE_NETWORK_NAME)
+
+

resets the currently selected network in Cytoscape back to the base one

+
+
+def select_neighbours_of_node(node: int, neighbour_iter_depth: int = 2, network_name: str = 'token_graph') ‑> None +
+
+
+ +Expand source code + +
def select_neighbours_of_node(
+    node: CytoNodeID,
+    neighbour_iter_depth: int = CYTO_ITER_NEIGHBOUR_DEPTH,
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+) -> None:
+    """Cytoscape: iterative selection of a node's neighbouring nodes and
+    their connecting edges
+
+    Parameters
+    ----------
+    node : CytoNodeID
+        node which neighbours should be selected
+    neighbour_iter_depth : int, optional
+        indicates how many levels of neighbours should be chosen, e.g. 1 --> only
+        first-level neighbours are considered which are directly connected to the node,
+        2 --> all nodes with iteration depth of 1 are chosen and additionally their
+        direct neighbours,
+        by default CYTO_ITER_NEIGHBOUR_DEPTH
+    network_name : str, optional
+        network to perform action on, by default CYTO_BASE_NETWORK_NAME
+    """
+    logger.debug('Selecting node neighbours for %s...', node)
+    p4c.clear_selection(network=network_name)
+    p4c.select_nodes(node, network=network_name)
+
+    for _ in range(neighbour_iter_depth):
+        _ = p4c.select_first_neighbors(network=network_name)
+
+    _ = p4c.select_edges_connecting_selected_nodes()
+    logger.debug('Selection of node neighbours for %s successful.', node)
+
+

Cytoscape: iterative selection of a node's neighbouring nodes and +their connecting edges

+

Parameters

+
+
node : CytoNodeID
+
node which neighbours should be selected
+
neighbour_iter_depth : int, optional
+
indicates how many levels of neighbours should be chosen, e.g. 1 –> only +first-level neighbours are considered which are directly connected to the node, +2 –> all nodes with iteration depth of 1 are chosen and additionally their +direct neighbours, +by default CYTO_ITER_NEIGHBOUR_DEPTH
+
network_name : str, optional
+
network to perform action on, by default CYTO_BASE_NETWORK_NAME
+
+
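+For illustration, a manual two-hop selection followed by a subnetwork export
+(the node SUID 12345 is a placeholder):
+
+from lang_main.render import cytoscape as cyto
+
+cyto.select_neighbours_of_node(node=12345, neighbour_iter_depth=2,
+                               network_name='token_graph')
+cyto.make_subnetwork(index=0, network_name='token_graph', export_image=False)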
+
+def verify_connection() ‑> None +
+
+
+ +Expand source code + +
def verify_connection() -> None:
+    """Cytoscape: checks if CyREST and Cytoscape versions are compatible nad
+    if Cytoscape API endpoint is reachable
+
+    Raises
+    ------
+    CyError
+        incompatible CyREST or Cytoscape versions
+    RequestException
+        API endpoint not reachable
+    """
+    try:
+        p4c.cytoscape_ping()
+    except CyError as error:  # pragma: no cover
+        logger.error('[CyError] CyREST or Cytoscape version not supported.')
+        raise error
+    except RequestException as error:
+        logger.error('[CytoAPIConnection] Connection to CyREST API failed.')
+        raise error
+
+

Cytoscape: checks if CyREST and Cytoscape versions are compatible and +if Cytoscape API endpoint is reachable

+

Raises

+
+
CyError
+
incompatible CyREST or Cytoscape versions
+
RequestException
+
API endpoint not reachable
+
+
+
+def verify_graph_render_size(graph: networkx.classes.digraph.DiGraph | networkx.classes.graph.Graph,
max_node_count: int | None = 500,
max_edge_count: int | None = 800) ‑> None
+
+
+
+ +Expand source code + +
def verify_graph_render_size(
+    graph: Graph | DiGraph,
+    max_node_count: int | None = CYTO_MAX_NODE_COUNT,
+    max_edge_count: int | None = CYTO_MAX_EDGE_COUNT,
+) -> None:
+    """verify that the graph size can still be handled within an acceptable time
+    frame for rendering in Cytoscape
+
+    Parameters
+    ----------
+    graph : Graph | DiGraph
+        graph to verify
+    max_node_count : int | None, optional
+        maximum allowed number of nodes, by default CYTO_MAX_NODE_COUNT
+    max_edge_count : int | None, optional
+        maximum allowed number of edges, by default CYTO_MAX_EDGE_COUNT
+
+    Raises
+    ------
+    GraphRenderError
+        if any of the provided limits is exceeded
+    """
+    num_nodes = len(graph.nodes)
+    num_edges = len(graph.edges)
+    if max_node_count is not None and num_nodes > max_node_count:
+        raise GraphRenderError(
+            f'Maximum number of nodes for rendering exceeded. '
+            f'Limit {max_node_count}, Counted: {num_nodes}'
+        )
+
+    if max_edge_count is not None and num_edges > max_edge_count:
+        raise GraphRenderError(
+            f'Maximum number of edges for rendering exceeded. '
+            f'Limit {max_edge_count}, Counted: {num_edges}'
+        )
+
+

verify that the graph size can still be handled within an acceptable time +frame for rendering in Cytoscape

+

Parameters

+
+
graph : Graph | DiGraph
+
graph to verify
+
max_node_count : int | None, optional
+
maximum allowed number of nodes, by default CYTO_MAX_NODE_COUNT
+
max_edge_count : int | None, optional
+
maximum allowed number of edges, by default CYTO_MAX_EDGE_COUNT
+
+

Raises

+
+
GraphRenderError
+
if any of the provided limits is exceeded
+
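+A small sketch of the guard (the graph is a placeholder; a complete graph with
+40 nodes has 780 edges and therefore stays within the documented defaults):
+
+import networkx as nx
+
+from lang_main.render import cytoscape as cyto
+
+graph = nx.complete_graph(40)
+cyto.verify_graph_render_size(graph, max_node_count=500, max_edge_count=800)
+# passes silently; nx.complete_graph(50) would exceed 800 edges and raise
+# GraphRenderError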
+
+
+def verify_table_property(property: str,
table_type: Literal['node', 'edge', 'network'] = 'node',
network_name: str = 'token_graph') ‑> bool
+
+
+
+ +Expand source code + +
def verify_table_property(
+    property: str,
+    table_type: Literal['node', 'edge', 'network'] = 'node',
+    network_name: str = CYTO_BASE_NETWORK_NAME,
+) -> bool:
+    table = p4c.get_table_columns(table=table_type, network=network_name)
+    logger.debug('Table >>%s<< with columns: %s', table, table.columns)
+
+    return property in table.columns
+
+
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/render/cytoscape_monkeypatch.html b/docs/lang_main/render/cytoscape_monkeypatch.html new file mode 100644 index 0000000..a7f4893 --- /dev/null +++ b/docs/lang_main/render/cytoscape_monkeypatch.html @@ -0,0 +1,182 @@ + + + + + + +lang_main.render.cytoscape_monkeypatch API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.render.cytoscape_monkeypatch

+
+
+
+
+
+
+
+
+

Functions

+
+
+def select_edges_connecting_selected_nodes(network=None, base_url='http://127.0.0.1:1234/v1') +
+
+
+ +Expand source code + +
@cy_log  # pragma: no cover
+def select_edges_connecting_selected_nodes(network=None, base_url=DEFAULT_BASE_URL):  # noqa: F405 # pragma: no cover
+    """Select edges in a Cytoscape Network connecting the selected nodes, including self loops connecting single nodes.
+
+    Any edges selected beforehand are deselected before any new edges are selected
+
+    Args:
+        network (SUID or str or None): Name or SUID of a network. Default is the
+            "current" network active in Cytoscape.
+        base_url (str): Ignore unless you need to specify a custom domain,
+            port or version to connect to the CyREST API. Default is http://127.0.0.1:1234
+            and the latest version of the CyREST API supported by this version of py4cytoscape.
+
+    Returns:
+         dict: {'nodes': [node list], 'edges': [edge list]} or None if no selected nodes
+    Raises:
+        CyError: if network name or SUID doesn't exist
+        requests.exceptions.RequestException: if can't connect to Cytoscape or Cytoscape returns an error
+
+    Examples:
+        >>> select_edges_connecting_selected_nodes()
+        None
+        >>> select_edges_connecting_selected_nodes(network='My Network')
+        {'nodes': [103990, 103991, ...], 'edges': [104432, 104431, ...]}
+        >>> select_edges_connecting_selected_nodes(network=52)
+        {'nodes': [103990, 103991, ...], 'edges': [104432, 104431, ...]}
+
+    Note:
+        In the return value node list is list of all selected nodes, and
+        edge list is the SUIDs of selected edges -- dict is None if no nodes were selected or there were no newly
+        created edges
+    """
+    net_suid = networks.get_network_suid(network, base_url=base_url)
+
+    selected_nodes = get_selected_nodes(network=net_suid, base_url=base_url)
+    # TODO: In R version, NA test is after len() test ... shouldn't it be before?
+    if not selected_nodes:
+        return None
+
+    all_edges = networks.get_all_edges(net_suid, base_url=base_url)
+
+    selected_sources = set()
+    selected_targets = set()
+    for n in selected_nodes:
+        n = re_parenthesis_1.sub(r'\(', n)  # type: ignore
+        n = re_parenthesis_2.sub(r'\)', n)  # type: ignore
+        selected_sources |= set(filter(re.compile('^' + n).search, all_edges))  # type: ignore
+        selected_targets |= set(filter(re.compile(n + '$').search, all_edges))  # type: ignore
+
+    selected_edges = list(selected_sources.intersection(selected_targets))
+
+    if len(selected_edges) == 0:
+        return None
+    res = select_edges(
+        selected_edges,
+        by_col='name',
+        preserve_current_selection=False,
+        network=net_suid,
+        base_url=base_url,
+    )
+    return res
+    # TODO: isn't the pattern match a bit cheesy ... shouldn't it be ^+n+' ('    and    ') '+n+$ ???
+
+

Select edges in a Cytoscape Network connecting the selected nodes, including self loops connecting single nodes.

+

Any edges selected beforehand are deselected before any new edges are selected

+

Args

+
+
network : SUID or str or None
+
Name or SUID of a network. Default is the +"current" network active in Cytoscape.
+
base_url : str
+
Ignore unless you need to specify a custom domain, +port or version to connect to the CyREST API. Default is http://127.0.0.1:1234 +and the latest version of the CyREST API supported by this version of py4cytoscape.
+
+

Returns

+
+
dict
+
{'nodes': [node list], 'edges': [edge list]} or None if no selected nodes
+
+

Raises

+
+
CyError
+
if network name or SUID doesn't exist
+
requests.exceptions.RequestException
+
if can't connect to Cytoscape or Cytoscape returns an error
+
+

Examples

+
>>> select_edges_connecting_selected_nodes()
+None
+>>> select_edges_connecting_selected_nodes(network='My Network')
+{'nodes': [103990, 103991, ...], 'edges': [104432, 104431, ...]}
+>>> select_edges_connecting_selected_nodes(network=52)
+{'nodes': [103990, 103991, ...], 'edges': [104432, 104431, ...]}
+
+

Note

+

In the return value node list is list of all selected nodes, and +edge list is the SUIDs of selected edges – dict is None if no nodes were selected or there were no newly +created edges

+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/render/index.html b/docs/lang_main/render/index.html new file mode 100644 index 0000000..a3588dd --- /dev/null +++ b/docs/lang_main/render/index.html @@ -0,0 +1,83 @@ + + + + + + +lang_main.render API documentation + + + + + + + + + + + +
+ + +
+ + + diff --git a/docs/lang_main/search.html b/docs/lang_main/search.html new file mode 100644 index 0000000..f2f18ce --- /dev/null +++ b/docs/lang_main/search.html @@ -0,0 +1,261 @@ + + + + + + +lang_main.search API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.search

+
+
+
+
+
+
+
+
+

Functions

+
+
+def search_base_path(starting_path: pathlib.Path, stop_folder_name: str | None = None) ‑> pathlib.Path | None +
+
+
+ +Expand source code + +
def search_base_path(
+    starting_path: Path,
+    stop_folder_name: str | None = None,
+) -> Path | None:
+    """Iteratively searches the parent directories of the starting path
+    and looks for folders matching the given name. If a match is encountered,
+    the parent path will be returned.
+
+    Example:
+    starting_path = path/to/start/folder
+    stop_folder_name = 'to'
+    returned path = 'path/'
+
+    Parameters
+    ----------
+    starting_path : Path
+        non-inclusive starting path
+    stop_folder_name : str, optional
+        name of the last folder in the directory tree to search, by default None
+
+    Returns
+    -------
+    Path | None
+        Path if corresponding base path was found, None otherwise
+    """
+    stop_folder_path: Path | None = None
+    base_path: Path | None = None
+    for search_path in starting_path.parents:
+        if stop_folder_name is not None and search_path.name == stop_folder_name:
+            # library is placed inside a whole python installation for deployment
+            # only look up to this folder
+            stop_folder_path = search_path
+            break
+
+    if stop_folder_path is not None:
+        base_path = stop_folder_path.parent
+
+    return base_path
+
+

Iteratively searches the parent directories of the starting path +and looks for folders matching the given name. If a match is encountered, +the parent path will be returned.

+

Example: +starting_path = path/to/start/folder +stop_folder_name = 'to' +returned path = 'path/'

+

Parameters

+
+
starting_path : Path
+
non-inclusive starting path
+
stop_folder_name : str, optional
+
name of the last folder in the directory tree to search, by default None
+
+

Returns

+
+
Path | None
+
Path if corresponding base path was found, None otherwise
+
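+A sketch matching the docstring example (the paths are illustrative only):
+
+from pathlib import Path
+
+from lang_main.search import search_base_path
+
+base = search_base_path(Path('path/to/start/folder'), stop_folder_name='to')
+assert base == Path('path')  # parent of the first ancestor folder named 'to'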
+
+
+def search_cwd(glob_pattern: str) ‑> pathlib.Path | None +
+
+
+ +Expand source code + +
def search_cwd(
+    glob_pattern: str,
+) -> Path | None:
+    """Searches the current working directory and looks for files
+    matching the glob pattern.
+    Returns the first match encountered.
+
+    Parameters
+    ----------
+    glob_pattern : str, optional
+        pattern to look for, first match will be returned
+
+    Returns
+    -------
+    Path | None
+        Path if corresponding object was found, None otherwise
+    """
+    path_found: Path | None = None
+    res = tuple(Path.cwd().glob(glob_pattern))
+    if res:
+        path_found = res[0]
+
+    return path_found
+
+

Searches the current working directory and looks for files +matching the glob pattern. +Returns the first match encountered.

+

Parameters

+
+
glob_pattern : str, optional
+
pattern to look for, first match will be returned
+
+

Returns

+
+
Path | None
+
Path if corresponding object was found, None otherwise
+
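+A small sketch (the glob pattern is an assumption for illustration):
+
+from lang_main.search import search_cwd
+
+config_file = search_cwd('*.toml')
+if config_file is None:
+    print('no TOML file in the current working directory')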
+
+
+def search_iterative(starting_path: pathlib.Path,
glob_pattern: str,
stop_folder_name: str | None = None) ‑> pathlib.Path | None
+
+
+
+ +Expand source code + +
def search_iterative(
+    starting_path: Path,
+    glob_pattern: str,
+    stop_folder_name: str | None = None,
+) -> Path | None:
+    """Iteratively searches the parent directories of the starting path
+    and looks for files matching the glob pattern. The starting path is not
+    searched, only its parents. Therefore, the starting path can also point
+    to a file; the folder in which it is placed will be searched.
+    Returns the first match encountered.
+    The parent of the stop folder will be searched if it exists.
+
+    Parameters
+    ----------
+    starting_path : Path
+        non-inclusive starting path
+    glob_pattern : str, optional
+        pattern to look for, first match will be returned
+    stop_folder_name : str, optional
+        name of the last folder in the directory tree to search, by default None
+
+    Returns
+    -------
+    Path | None
+        Path if corresponding object was found, None otherwise
+    """
+    file_path: Path | None = None
+    stop_folder_reached: bool = False
+    for search_path in starting_path.parents:
+        res = tuple(search_path.glob(glob_pattern))
+        if res:
+            file_path = res[0]
+            break
+        elif stop_folder_reached:
+            break
+
+        if stop_folder_name is not None and search_path.name == stop_folder_name:
+            # library is placed inside a whole python installation for deployment
+            # if this folder is reached, only look up one parent above
+            stop_folder_reached = True
+
+    return file_path
+
+

Iteratively searches the parent directories of the starting path +and looks for files matching the glob pattern. The starting path is not +searched, only its parents. Therefore, the starting path can also point +to a file; the folder in which it is placed will be searched. +Returns the first match encountered. +The parent of the stop folder will be searched if it exists.

+

Parameters

+
+
starting_path : Path
+
non-inclusive starting path
+
glob_pattern : str, optional
+
pattern to look for, first match will be returned
+
stop_folder_name : str, optional
+
name of the last folder in the directory tree to search, by default None
+
+

Returns

+
+
Path | None
+
Path if corresponding object was found, None otherwise
+
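+A sketch of an upward search for a configuration file (file name and stop folder
+are assumptions for illustration):
+
+from pathlib import Path
+
+from lang_main.search import search_iterative
+
+cfg = search_iterative(
+    starting_path=Path(__file__),
+    glob_pattern='lang_main_config.toml',
+    stop_folder_name='site-packages',
+)
+if cfg is None:
+    print('no configuration file found above this module')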
+
+
+
+
+
+
+ +
+ + + diff --git a/docs/lang_main/types.html b/docs/lang_main/types.html new file mode 100644 index 0000000..6de1573 --- /dev/null +++ b/docs/lang_main/types.html @@ -0,0 +1,10637 @@ + + + + + + +lang_main.types API documentation + + + + + + + + + + + +
+
+
+

Module lang_main.types

+
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class SpacyDoc +(...) +
+
+

Doc(Vocab vocab, words=None, spaces=None, user_data=None, *, tags=None, pos=None, morphs=None, lemmas=None, heads=None, deps=None, sent_starts=None, ents=None) +A sequence of Token objects. Access sentences and named entities, export +annotations to numpy arrays, losslessly serialize to compressed binary +strings. The Doc object holds an array of TokenC structs. The +Python-level Token and Span objects are views of this array, i.e. +they don't own the data themselves.

+
EXAMPLE:
+    Construction 1
+    >>> doc = nlp(u'Some text')
+
+    Construction 2
+    >>> from spacy.tokens import Doc
+    >>> doc = Doc(nlp.vocab, words=["hello", "world", "!"], spaces=[True, False, False])
+
+DOCS: <https://spacy.io/api/doc>
+
+

Create a Doc object.

+

vocab (Vocab): A vocabulary object, which must match any models you +want to use (e.g. tokenizer, parser, entity recognizer). +words (Optional[List[Union[str, int]]]): A list of unicode strings or +hash values to add to the document as words. If None, defaults to +empty list. +spaces (Optional[List[bool]]): A list of boolean values, of the same +length as words. True means that the word is followed by a space, +False means it is not. If None, defaults to [True]*len(words) +user_data (dict or None): Optional extra data to attach to the Doc. +tags (Optional[List[str]]): A list of unicode strings, of the same +length as words, to assign as token.tag. Defaults to None. +pos (Optional[List[str]]): A list of unicode strings, of the same +length as words, to assign as token.pos. Defaults to None. +morphs (Optional[List[str]]): A list of unicode strings, of the same +length as words, to assign as token.morph. Defaults to None. +lemmas (Optional[List[str]]): A list of unicode strings, of the same +length as words, to assign as token.lemma. Defaults to None. +heads (Optional[List[int]]): A list of values, of the same length as +words, to assign as heads. Head indices are the position of the +head in the doc. Defaults to None. +deps (Optional[List[str]]): A list of unicode strings, of the same +length as words, to assign as token.dep. Defaults to None. +sent_starts (Optional[List[Union[bool, int, None]]]): A list of values, +of the same length as words, to assign as token.is_sent_start. Will +be overridden by heads if heads is provided. Defaults to None. +ents (Optional[List[str]]): A list of unicode strings, of the same +length as words, as IOB tags to assign as token.ent_iob and +token.ent_type. Defaults to None.

+

DOCS: https://spacy.io/api/doc#init

+

Static methods

+
+
+def from_docs(...) +
+
+

Doc.from_docs(docs, ensure_whitespace=True, attrs=None, *, exclude=tuple()) +Concatenate multiple Doc objects to form a new one. Raises an error +if the Doc objects do not all share the same Vocab.

+
    docs (list): A list of Doc objects.
+    ensure_whitespace (bool): Insert a space between two adjacent docs
+        whenever the first doc does not end in whitespace.
+    attrs (list): Optional list of attribute ID ints or attribute name
+        strings.
+    exclude (Iterable[str]): Doc attributes to exclude. Supported
+        attributes: <code>spans</code>, <code>tensor</code>, <code>user\_data</code>.
+    RETURNS (Doc): A doc that contains the concatenated docs, or None if no
+        docs were given.
+
+    DOCS: <https://spacy.io/api/doc#from_docs>
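+    EXAMPLE (a minimal usage sketch, assuming an `nlp` pipeline is already loaded):
+        >>> doc1 = nlp("Berlin is a city.")
+        >>> doc2 = nlp("It is the capital of Germany.")
+        >>> merged = Doc.from_docs([doc1, doc2])
+        >>> merged.text
+        'Berlin is a city. It is the capital of Germany.'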
+
+
+
+

Instance variables

+
+
var cats
+
+

cats: object

+
+
var doc
+
+
+
+
var ents
+
+

The named entities in the document. Returns a tuple of named entity +Span objects, if the entity recognizer has been applied.

+

RETURNS (tuple): Entities in the document, one Span per entity.

+

DOCS: https://spacy.io/api/doc#ents

+
+
var has_unknown_spaces
+
+

has_unknown_spaces: 'bool'

+
+
var has_vector
+
+

A boolean value indicating whether a word vector is associated with +the object.

+

RETURNS (bool): Whether a word vector is associated with the object.

+

DOCS: https://spacy.io/api/doc#has_vector

+
+
var is_nered
+
+
+
+
var is_parsed
+
+
+
+
var is_sentenced
+
+
+
+
var is_tagged
+
+
+
+
var lang
+
+

RETURNS (uint64): ID of the language of the doc's vocabulary.

+
+
var lang_
+
+

RETURNS (str): Language of the doc's vocabulary, e.g. 'en'.

+
+
var mem
+
+
+
+
var noun_chunks
+
+

Iterate over the base noun phrases in the document. Yields base +noun-phrase Span objects, if the language has a noun chunk iterator. +Raises a NotImplementedError otherwise.

+

A base noun phrase, or "NP chunk", is a noun +phrase that does not permit other NPs to be nested within it – so no +NP-level coordination, no prepositional phrases, and no relative +clauses.

+

YIELDS (Span): Noun chunks in the document.

+

DOCS: https://spacy.io/api/doc#noun_chunks

+
+
var noun_chunks_iterator
+
+

noun_chunks_iterator: object

+
+
var sentiment
+
+

sentiment: 'float'

+
+
var sents
+
+

Iterate over the sentences in the document. Yields sentence Span +objects. Sentence spans have no label.

+

YIELDS (Span): Sentences in the document.

+

DOCS: https://spacy.io/api/doc#sents

+
+
var spans
+
+
+
+
var tensor
+
+

tensor: object

+
+
var text
+
+

A unicode representation of the document text.

+

RETURNS (str): The original verbatim text of the document.

+
+
var text_with_ws
+
+

An alias of Doc.text, provided for duck-type compatibility with +Span and Token.

+

RETURNS (str): The original verbatim text of the document.

+
+
var user_data
+
+

user_data: object

+
+
var user_hooks
+
+

user_hooks: dict

+
+
var user_span_hooks
+
+

user_span_hooks: dict

+
+
var user_token_hooks
+
+

user_token_hooks: dict

+
+
var vector
+
+

A real-valued meaning representation. Defaults to an average of the +token vectors.

+

RETURNS (numpy.ndarray[ndim=1, dtype='float32']): A 1D numpy array +representing the document's semantics.

+

DOCS: https://spacy.io/api/doc#vector

+
+
var vector_norm
+
+

The L2 norm of the document's vector representation.

+

RETURNS (float): The L2 norm of the vector representation.

+

DOCS: https://spacy.io/api/doc#vector_norm

+
+
var vocab
+
+
+
+
+

Methods

+
+
+def char_span(...) +
+
+

Doc.char_span(self, int start_idx, int end_idx, label=0, kb_id=0, vector=None, alignment_mode='strict', span_id=0) +Create a Span object from the slice +doc.text[start_idx : end_idx]. Returns None if no valid Span can be +created.

+
    doc (Doc): The parent document.
+    start_idx (int): The index of the first character of the span.
+    end_idx (int): The index of the first character after the span.
+    label (Union[int, str]): A label to attach to the Span, e.g. for
+        named entities.
+    kb_id (Union[int, str]):  An ID from a KB to capture the meaning of a
+        named entity.
+    vector (ndarray[ndim=1, dtype='float32']): A meaning representation of
+        the span.
+    alignment_mode (str): How character indices are aligned to token
+        boundaries. Options: "strict" (character indices must be aligned
+        with token boundaries), "contract" (span of all tokens completely
+        within the character span), "expand" (span of all tokens at least
+        partially covered by the character span). Defaults to "strict".
+    span_id (Union[int, str]): An identifier to associate with the span.
+    RETURNS (Span): The newly constructed object.
+
+    DOCS: <https://spacy.io/api/doc#char_span>
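+    EXAMPLE (illustrative, assuming an `nlp` pipeline is loaded):
+        >>> doc = nlp("I like New York")
+        >>> span = doc.char_span(7, 15, label="GPE")
+        >>> span.text
+        'New York'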
+
+
+
+def copy(...) +
+
+

Doc.copy(self)

+
+
+def count_by(...) +
+
+

Doc.count_by(self, attr_id_t attr_id, exclude=None, counts=None) +Count the frequencies of a given attribute. Produces a dict of +{attribute (int): count (ints)} frequencies, keyed by the values of +the given attribute ID.

+
    attr_id (int): The attribute ID to key the counts.
+    RETURNS (dict): A dictionary mapping attributes to integer counts.
+
+    DOCS: <https://spacy.io/api/doc#count_by>
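+    EXAMPLE (illustrative):
+        >>> from spacy.attrs import ORTH
+        >>> doc = nlp("apple apple orange")
+        >>> counts = doc.count_by(ORTH)
+        >>> counts[doc.vocab.strings["apple"]]
+        2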
+
+
+
+def extend_tensor(...) +
+
+

Doc.extend_tensor(self, tensor) +Concatenate a new tensor onto the doc.tensor object.

+
    The doc.tensor attribute holds dense feature vectors
+    computed by the models in the pipeline. Let's say a
+    document with 30 words has a tensor with 128 dimensions
+    per word. doc.tensor.shape will be (30, 128). After
+    calling doc.extend_tensor with an array of shape (30, 64),
+    doc.tensor == (30, 192).
+
+
+
+def from_array(...) +
+
+

Doc.from_array(self, attrs, array) +Load attributes from a numpy array. Write to a Doc object, from an +(M, N) array of attributes.

+
    attrs (list) A list of attribute ID ints.
+    array (numpy.ndarray[ndim=2, dtype='int32']): The attribute values.
+    RETURNS (Doc): Itself.
+
+    DOCS: <https://spacy.io/api/doc#from_array>
+
+
+
+def from_bytes(...) +
+
+

Doc.from_bytes(self, bytes_data, *, exclude=tuple()) +Deserialize, i.e. import the document contents from a binary string.

+
    data (bytes): The string to load from.
+    exclude (Iterable[str]): String names of serialization fields to exclude.
+    RETURNS (Doc): Itself.
+
+    DOCS: <https://spacy.io/api/doc#from_bytes>
+
+
+
+def from_dict(...) +
+
+

Doc.from_dict(self, msg, *, exclude=tuple()) +Deserialize the document contents from a dictionary representation.

+
    msg (Dict[str, Any]): The dictionary to load from.
+    exclude (Iterable[str]): String names of serialization fields to exclude.
+    RETURNS (Doc): Itself.
+
+
+
+def from_disk(...) +
+
+

Doc.from_disk(self, path, *, exclude=tuple()) +Loads state from a directory. Modifies the object in place and +returns it.

+
    path (str / Path): A path to a directory. Paths may be either
+        strings or <code>Path</code>-like objects.
+    exclude (Iterable[str]): String names of serialization fields to exclude.
+    RETURNS (Doc): The modified <code><a title="lang_main.types.Doc" href="#lang_main.types.Doc">Doc</a></code> object.
+
+    DOCS: <https://spacy.io/api/doc#from_disk>
+
+
+
+def from_json(...) +
+
+

Doc.from_json(self, doc_json, *, validate=False) +Convert a JSON document generated by Doc.to_json() to a Doc.

+
    doc_json (Dict): JSON representation of doc object to load.
+    validate (bool): Whether to validate <code>doc\_json</code> against the expected schema.
+        Defaults to False.
+    RETURNS (Doc): A doc instance corresponding to the specified JSON representation.
+
+
+
+def get_extension(...) +
+
+

Doc.get_extension(type cls, name) +Look up a previously registered extension by name.

+
    name (str): Name of the extension.
+    RETURNS (tuple): A <code>(default, method, getter, setter)</code> tuple.
+
+    DOCS: <https://spacy.io/api/doc#get_extension>
+
+
+
+def get_lca_matrix(...) +
+
+

Doc.get_lca_matrix(self) +Calculates a matrix of Lowest Common Ancestors (LCA) for a given +Doc, where LCA[i, j] is the index of the lowest common ancestor among +token i and j.

+
    RETURNS (np.array[ndim=2, dtype=numpy.int32]): LCA matrix with shape
+        (n, n), where n = len(self).
+
+    DOCS: <https://spacy.io/api/doc#get_lca_matrix>
+
+
+
+def has_annotation(...) +
+
+

Doc.has_annotation(self, attr, *, require_complete=False) +Check whether the doc contains annotation on a token attribute.

+
    attr (Union[int, str]): The attribute string name or int ID.
+    require_complete (bool): Whether to check that the attribute is set on
+        every token in the doc.
+    RETURNS (bool): Whether annotation is present.
+
+    DOCS: <https://spacy.io/api/doc#has_annotation>
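+    EXAMPLE (illustrative; the result depends on which components have run):
+        >>> doc = nlp("A short sentence.")
+        >>> doc.has_annotation("DEP")  # True only if a dependency parser has run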
+
+
+
+def has_extension(...) +
+
+

Doc.has_extension(type cls, name) +Check whether an extension has been registered.

+
    name (str): Name of the extension.
+    RETURNS (bool): Whether the extension has been registered.
+
+    DOCS: <https://spacy.io/api/doc#has_extension>
+
+
+
+def remove_extension(...) +
+
+

Doc.remove_extension(type cls, name) +Remove a previously registered extension.

+
    name (str): Name of the extension.
+    RETURNS (tuple): A <code>(default, method, getter, setter)</code> tuple of the
+        removed extension.
+
+    DOCS: <https://spacy.io/api/doc#remove_extension>
+
+
+
+def retokenize(...) +
+
+

Doc.retokenize(self) +Context manager to handle retokenization of the Doc. +Modifications to the Doc's tokenization are stored, and then +made all at once when the context manager exits. This is +much more efficient, and less error-prone.

+
    All views of the Doc (Span and Token) created before the
+    retokenization are invalidated, although they may accidentally
+    continue to work.
+
+    DOCS: <https://spacy.io/api/doc#retokenize>
+    USAGE: <https://spacy.io/usage/linguistic-features#retokenization>
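+    EXAMPLE (illustrative):
+        >>> doc = nlp("I live in New York")
+        >>> with doc.retokenize() as retokenizer:
+        ...     retokenizer.merge(doc[3:5])  # merge "New" and "York" into one token
+        >>> len(doc)
+        4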
+
+
+
+def set_ents(...) +
+
+

Doc.set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside) +Set entity annotation.

+
    entities (List[Span]): Spans with labels to set as entities.
+    blocked (Optional[List[Span]]): Spans to set as 'blocked' (never an
+        entity) for spacy's built-in NER component. Other components may
+        ignore this setting.
+    missing (Optional[List[Span]]): Spans with missing/unknown entity
+        information.
+    outside (Optional[List[Span]]): Spans outside of entities (O in IOB).
+    default (str): How to set entity annotation for tokens outside of any
+        provided spans. Options: "blocked", "missing", "outside" and
+        "unmodified" (preserve current state). Defaults to "outside".
+
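+    EXAMPLE (illustrative):
+        >>> from spacy.tokens import Span
+        >>> doc = nlp.make_doc("Mark works at Acme")
+        >>> doc.set_ents([Span(doc, 3, 4, label="ORG")])
+        >>> [(ent.text, ent.label_) for ent in doc.ents]
+        [('Acme', 'ORG')]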
+
+
+def set_extension(...) +
+
+

Doc.set_extension(type cls, name, **kwargs) +Define a custom attribute which becomes available as Doc._.

+
    name (str): Name of the attribute to set.
+    default: Optional default value of the attribute.
+    getter (callable): Optional getter function.
+    setter (callable): Optional setter function.
+    method (callable): Optional method for method extension.
+    force (bool): Force overwriting existing attribute.
+
+    DOCS: <https://spacy.io/api/doc#set_extension>
+    USAGE: <https://spacy.io/usage/processing-pipelines#custom-components-attributes>
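+    EXAMPLE (illustrative):
+        >>> from spacy.tokens import Doc
+        >>> Doc.set_extension("is_greeting", default=False)
+        >>> doc = nlp("Hello there!")
+        >>> doc._.is_greeting = True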
+
+
+
+def similarity(...) +
+
+

Doc.similarity(self, other) +Make a semantic similarity estimate. The default estimate is cosine +similarity using an average of word vectors.

+
    other (object): The object to compare with. By default, accepts <code><a title="lang_main.types.Doc" href="#lang_main.types.Doc">Doc</a></code>,
+        <code>Span</code>, <code><a title="lang_main.types.Token" href="#lang_main.types.Token">Token</a></code> and <code>Lexeme</code> objects.
+    RETURNS (float): A scalar similarity score. Higher is more similar.
+
+    DOCS: <https://spacy.io/api/doc#similarity>
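+    EXAMPLE (illustrative; meaningful scores require a pipeline with word vectors,
+        e.g. a *_md or *_lg model):
+        >>> doc1 = nlp("I like apples")
+        >>> doc2 = nlp("I like oranges")
+        >>> score = doc1.similarity(doc2)  # float, higher means more similar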
+
+
+
+def to_array(...) +
+
+

Doc.to_array(self, py_attr_ids) -> ndarray +Export given token attributes to a numpy ndarray. +If attr_ids is a sequence of M attributes, the output array will be +of shape (N, M), where N is the length of the Doc (in tokens). If +attr_ids is a single attribute, the output shape will be (N,). You +can specify attributes by integer ID (e.g. spacy.attrs.LEMMA) or +string name (e.g. 'LEMMA' or 'lemma').

+
    py_attr_ids (list[]): A list of attributes (int IDs or string names).
+    RETURNS (numpy.ndarray[long, ndim=2]): A feature matrix, with one row
+        per word, and one column per attribute indicated in the input
+        <code>attr\_ids</code>.
+
+    EXAMPLE:
+        >>> from spacy.attrs import LOWER, POS, ENT_TYPE, IS_ALPHA
+        >>> doc = nlp(text)
+        >>> # All strings mapped to integers, for easy export to numpy
+        >>> np_array = doc.to_array([LOWER, POS, ENT_TYPE, IS_ALPHA])
+
+
+
+def to_bytes(...) +
+
+

Doc.to_bytes(self, *, exclude=tuple()) +Serialize, i.e. export the document contents to a binary string.

+
    exclude (Iterable[str]): String names of serialization fields to exclude.
+    RETURNS (bytes): A losslessly serialized copy of the <code><a title="lang_main.types.Doc" href="#lang_main.types.Doc">Doc</a></code>, including
+        all annotations.
+
+    DOCS: <https://spacy.io/api/doc#to_bytes>
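+    EXAMPLE (illustrative round trip with Doc.from_bytes):
+        >>> data = doc.to_bytes()
+        >>> restored = Doc(doc.vocab).from_bytes(data)
+        >>> restored.text == doc.text
+        True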
+
+
+
+def to_dict(...) +
+
+

Doc.to_dict(self, *, exclude=tuple()) +Export the document contents to a dictionary for serialization.

+
    exclude (Iterable[str]): String names of serialization fields to exclude.
+    RETURNS (Dict[str, Any]): A dictionary representation of the <code><a title="lang_main.types.Doc" href="#lang_main.types.Doc">Doc</a></code>
+
+
+
+def to_disk(...) +
+
+

Doc.to_disk(self, path, *, exclude=tuple()) +Save the current state to a directory.

+
    path (str / Path): A path to a directory, which will be created if
+        it doesn't exist. Paths may be either strings or Path-like objects.
+    exclude (Iterable[str]): String names of serialization fields to exclude.
+
+    DOCS: <https://spacy.io/api/doc#to_disk>
+
+
+
+def to_json(...) +
+
+

Doc.to_json(self, underscore=None) +Convert a Doc to JSON.

+
    underscore (list): Optional list of string names of custom doc._.
+    attributes. Attribute values need to be JSON-serializable. Values will
+    be added to an "_" key in the data, e.g. "_": {"foo": "bar"}.
+    RETURNS (dict): The data in JSON format.
+
+
+
+def to_utf8_array(...) +
+
+

Doc.to_utf8_array(self, int nr_char=-1) +Encode word strings to utf8, and export to a fixed-width array +of characters. Characters are placed into the array in the order: +0, -1, 1, -2, etc +For example, if the array is sliced array[:, :8], the array will +contain the first 4 characters and last 4 characters of each word — +with the middle characters clipped out. The value 255 is used as a pad +value.

+
+
+
+
+class SpacyModel +(vocab: spacy.vocab.Vocab | bool = True,
*,
max_length: int = 1000000,
meta: Dict[str, Any] = {},
create_tokenizer: Callable[[ForwardRef('Language')], Callable[[str], spacy.tokens.doc.Doc]] | None = None,
create_vectors: Callable[[ForwardRef('Vocab')], spacy.vectors.BaseVectors] | None = None,
batch_size: int = 1000,
**kwargs)
+
+
+
+ +Expand source code + +
class Language:
+    """A text-processing pipeline. Usually you'll load this once per process,
+    and pass the instance around your application.
+
+    Defaults (class): Settings, data and factory methods for creating the `nlp`
+        object and processing pipeline.
+    lang (str): IETF language code, such as 'en'.
+
+    DOCS: https://spacy.io/api/language
+    """
+
+    Defaults = BaseDefaults
+    lang: Optional[str] = None
+    default_config = DEFAULT_CONFIG
+
+    factories = SimpleFrozenDict(error=Errors.E957)
+    _factory_meta: Dict[str, "FactoryMeta"] = {}  # meta by factory
+
+    def __init__(
+        self,
+        vocab: Union[Vocab, bool] = True,
+        *,
+        max_length: int = 10**6,
+        meta: Dict[str, Any] = {},
+        create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
+        create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None,
+        batch_size: int = 1000,
+        **kwargs,
+    ) -> None:
+        """Initialise a Language object.
+
+        vocab (Vocab): A `Vocab` object. If `True`, a vocab is created.
+        meta (dict): Custom meta data for the Language class. Is written to by
+            models to add model meta data.
+        max_length (int): Maximum number of characters in a single text. The
+            current models may run out of memory on extremely long texts, due to
+            large internal allocations. You should segment these texts into
+            meaningful units, e.g. paragraphs, subsections etc, before passing
+            them to spaCy. Default maximum length is 1,000,000 characters (1 MB). As
+            a rule of thumb, if all pipeline components are enabled, spaCy's
+            default models currently require roughly 1GB of temporary memory per
+            100,000 characters in one text.
+        create_tokenizer (Callable): Function that takes the nlp object and
+            returns a tokenizer.
+        batch_size (int): Default batch size for pipe and evaluate.
+
+        DOCS: https://spacy.io/api/language#init
+        """
+        # We're only calling this to import all factories provided via entry
+        # points. The factory decorator applied to these functions takes care
+        # of the rest.
+        util.registry._entry_point_factories.get_all()
+
+        self._config = DEFAULT_CONFIG.merge(self.default_config)
+        self._meta = dict(meta)
+        self._path = None
+        self._optimizer: Optional[Optimizer] = None
+        # Component meta and configs are only needed on the instance
+        self._pipe_meta: Dict[str, "FactoryMeta"] = {}  # meta by component
+        self._pipe_configs: Dict[str, Config] = {}  # config by component
+
+        if not isinstance(vocab, Vocab) and vocab is not True:
+            raise ValueError(Errors.E918.format(vocab=vocab, vocab_type=type(Vocab)))
+        if vocab is True:
+            vectors_name = meta.get("vectors", {}).get("name")
+            vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
+            if not create_vectors:
+                vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
+                create_vectors = registry.resolve(vectors_cfg)["vectors"]
+            vocab.vectors = create_vectors(vocab)
+        else:
+            if (self.lang and vocab.lang) and (self.lang != vocab.lang):
+                raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
+        self.vocab: Vocab = vocab
+        if self.lang is None:
+            self.lang = self.vocab.lang
+        self._components: List[Tuple[str, PipeCallable]] = []
+        self._disabled: Set[str] = set()
+        self.max_length = max_length
+        # Create the default tokenizer from the default config
+        if not create_tokenizer:
+            tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
+            create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
+        self.tokenizer = create_tokenizer(self)
+        self.batch_size = batch_size
+        self.default_error_handler = raise_error
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        cls.default_config = DEFAULT_CONFIG.merge(cls.Defaults.config)
+        cls.default_config["nlp"]["lang"] = cls.lang
+
+    @property
+    def path(self):
+        return self._path
+
+    @property
+    def meta(self) -> Dict[str, Any]:
+        """Custom meta data of the language class. If a model is loaded, this
+        includes details from the model's meta.json.
+
+        RETURNS (Dict[str, Any]): The meta.
+
+        DOCS: https://spacy.io/api/language#meta
+        """
+        spacy_version = util.get_minor_version_range(about.__version__)
+        if self.vocab.lang:
+            self._meta.setdefault("lang", self.vocab.lang)
+        else:
+            self._meta.setdefault("lang", self.lang)
+        self._meta.setdefault("name", "pipeline")
+        self._meta.setdefault("version", "0.0.0")
+        self._meta.setdefault("spacy_version", spacy_version)
+        self._meta.setdefault("description", "")
+        self._meta.setdefault("author", "")
+        self._meta.setdefault("email", "")
+        self._meta.setdefault("url", "")
+        self._meta.setdefault("license", "")
+        self._meta.setdefault("spacy_git_version", GIT_VERSION)
+        self._meta["vectors"] = {
+            "width": self.vocab.vectors_length,
+            "vectors": len(self.vocab.vectors),
+            "keys": self.vocab.vectors.n_keys,
+            "name": self.vocab.vectors.name,
+            "mode": self.vocab.vectors.mode,
+        }
+        self._meta["labels"] = dict(self.pipe_labels)
+        # TODO: Adding this back to prevent breaking people's code etc., but
+        # we should consider removing it
+        self._meta["pipeline"] = list(self.pipe_names)
+        self._meta["components"] = list(self.component_names)
+        self._meta["disabled"] = list(self.disabled)
+        return self._meta
+
+    @meta.setter
+    def meta(self, value: Dict[str, Any]) -> None:
+        self._meta = value
+
+    @property
+    def config(self) -> Config:
+        """Trainable config for the current language instance. Includes the
+        current pipeline components, as well as default training config.
+
+        RETURNS (thinc.api.Config): The config.
+
+        DOCS: https://spacy.io/api/language#config
+        """
+        self._config.setdefault("nlp", {})
+        self._config.setdefault("training", {})
+        self._config["nlp"]["lang"] = self.lang
+        # We're storing the filled config for each pipeline component and so
+        # we can populate the config again later
+        pipeline = {}
+        score_weights = []
+        for pipe_name in self.component_names:
+            pipe_meta = self.get_pipe_meta(pipe_name)
+            pipe_config = self.get_pipe_config(pipe_name)
+            pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
+            if pipe_meta.default_score_weights:
+                score_weights.append(pipe_meta.default_score_weights)
+        self._config["nlp"]["pipeline"] = list(self.component_names)
+        self._config["nlp"]["disabled"] = list(self.disabled)
+        self._config["components"] = pipeline
+        # We're merging the existing score weights back into the combined
+        # weights to make sure we're preserving custom settings in the config
+        # but also reflect updates (e.g. new components added)
+        prev_weights = self._config["training"].get("score_weights", {})
+        combined_score_weights = combine_score_weights(score_weights, prev_weights)
+        self._config["training"]["score_weights"] = combined_score_weights
+        if not srsly.is_json_serializable(self._config):
+            raise ValueError(Errors.E961.format(config=self._config))
+        return self._config
+
+    @config.setter
+    def config(self, value: Config) -> None:
+        self._config = value
+
+    @property
+    def disabled(self) -> List[str]:
+        """Get the names of all disabled components.
+
+        RETURNS (List[str]): The disabled components.
+        """
+        # Make sure the disabled components are returned in the order they
+        # appear in the pipeline (which isn't guaranteed by the set)
+        names = [name for name, _ in self._components if name in self._disabled]
+        return SimpleFrozenList(names, error=Errors.E926.format(attr="disabled"))
+
+    @property
+    def factory_names(self) -> List[str]:
+        """Get names of all available factories.
+
+        RETURNS (List[str]): The factory names.
+        """
+        names = list(self.factories.keys())
+        return SimpleFrozenList(names)
+
+    @property
+    def components(self) -> List[Tuple[str, PipeCallable]]:
+        """Get all (name, component) tuples in the pipeline, including the
+        currently disabled components.
+        """
+        return SimpleFrozenList(
+            self._components, error=Errors.E926.format(attr="components")
+        )
+
+    @property
+    def component_names(self) -> List[str]:
+        """Get the names of the available pipeline components. Includes all
+        active and inactive pipeline components.
+
+        RETURNS (List[str]): List of component name strings, in order.
+        """
+        names = [pipe_name for pipe_name, _ in self._components]
+        return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
+
+    @property
+    def pipeline(self) -> List[Tuple[str, PipeCallable]]:
+        """The processing pipeline consisting of (name, component) tuples. The
+        components are called on the Doc in order as it passes through the
+        pipeline.
+
+        RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
+        """
+        pipes = [(n, p) for n, p in self._components if n not in self._disabled]
+        return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
+
+    @property
+    def pipe_names(self) -> List[str]:
+        """Get names of available active pipeline components.
+
+        RETURNS (List[str]): List of component name strings, in order.
+        """
+        names = [pipe_name for pipe_name, _ in self.pipeline]
+        return SimpleFrozenList(names, error=Errors.E926.format(attr="pipe_names"))
+
+    @property
+    def pipe_factories(self) -> Dict[str, str]:
+        """Get the component factories for the available pipeline components.
+
+        RETURNS (Dict[str, str]): Factory names, keyed by component names.
+        """
+        factories = {}
+        for pipe_name, pipe in self._components:
+            factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
+        return SimpleFrozenDict(factories)
+
+    @property
+    def pipe_labels(self) -> Dict[str, List[str]]:
+        """Get the labels set by the pipeline components, if available (if
+        the component exposes a labels property and the labels are not
+        hidden).
+
+        RETURNS (Dict[str, List[str]]): Labels keyed by component name.
+        """
+        labels = {}
+        for name, pipe in self._components:
+            if hasattr(pipe, "hide_labels") and pipe.hide_labels is True:
+                continue
+            if hasattr(pipe, "labels"):
+                labels[name] = list(pipe.labels)
+        return SimpleFrozenDict(labels)
+
+    @classmethod
+    def has_factory(cls, name: str) -> bool:
+        """RETURNS (bool): Whether a factory of that name is registered."""
+        internal_name = cls.get_factory_name(name)
+        return name in registry.factories or internal_name in registry.factories
+
+    @classmethod
+    def get_factory_name(cls, name: str) -> str:
+        """Get the internal factory name based on the language subclass.
+
+        name (str): The factory name.
+        RETURNS (str): The internal factory name.
+        """
+        if cls.lang is None:
+            return name
+        return f"{cls.lang}.{name}"
+
+    @classmethod
+    def get_factory_meta(cls, name: str) -> "FactoryMeta":
+        """Get the meta information for a given factory name.
+
+        name (str): The component factory name.
+        RETURNS (FactoryMeta): The meta for the given factory name.
+        """
+        internal_name = cls.get_factory_name(name)
+        if internal_name in cls._factory_meta:
+            return cls._factory_meta[internal_name]
+        if name in cls._factory_meta:
+            return cls._factory_meta[name]
+        raise ValueError(Errors.E967.format(meta="factory", name=name))
+
+    @classmethod
+    def set_factory_meta(cls, name: str, value: "FactoryMeta") -> None:
+        """Set the meta information for a given factory name.
+
+        name (str): The component factory name.
+        value (FactoryMeta): The meta to set.
+        """
+        cls._factory_meta[cls.get_factory_name(name)] = value
+
+    def get_pipe_meta(self, name: str) -> "FactoryMeta":
+        """Get the meta information for a given component name.
+
+        name (str): The component name.
+        RETURNS (FactoryMeta): The meta for the given component name.
+        """
+        if name not in self._pipe_meta:
+            raise ValueError(Errors.E967.format(meta="component", name=name))
+        return self._pipe_meta[name]
+
+    def get_pipe_config(self, name: str) -> Config:
+        """Get the config used to create a pipeline component.
+
+        name (str): The component name.
+        RETURNS (Config): The config used to create the pipeline component.
+        """
+        if name not in self._pipe_configs:
+            raise ValueError(Errors.E960.format(name=name))
+        pipe_config = self._pipe_configs[name]
+        return pipe_config
+
+    @classmethod
+    def factory(
+        cls,
+        name: str,
+        *,
+        default_config: Dict[str, Any] = SimpleFrozenDict(),
+        assigns: Iterable[str] = SimpleFrozenList(),
+        requires: Iterable[str] = SimpleFrozenList(),
+        retokenizes: bool = False,
+        default_score_weights: Dict[str, Optional[float]] = SimpleFrozenDict(),
+        func: Optional[Callable] = None,
+    ) -> Callable:
+        """Register a new pipeline component factory. Can be used as a decorator
+        on a function or classmethod, or called as a function with the factory
+        provided as the func keyword argument. To create a component and add
+        it to the pipeline, you can use nlp.add_pipe(name).
+
+        name (str): The name of the component factory.
+        default_config (Dict[str, Any]): Default configuration, describing the
+            default values of the factory arguments.
+        assigns (Iterable[str]): Doc/Token attributes assigned by this component,
+            e.g. "token.ent_id". Used for pipeline analysis.
+        requires (Iterable[str]): Doc/Token attributes required by this component,
+            e.g. "token.ent_id". Used for pipeline analysis.
+        retokenizes (bool): Whether the component changes the tokenization.
+            Used for pipeline analysis.
+        default_score_weights (Dict[str, Optional[float]]): The scores to report during
+            training, and their default weight towards the final score used to
+            select the best model. Weights should sum to 1.0 per component and
+            will be combined and normalized for the whole pipeline. If None,
+            the score won't be shown in the logs or be weighted.
+        func (Optional[Callable]): Factory function if not used as a decorator.
+
+        DOCS: https://spacy.io/api/language#factory
+        """
+        if not isinstance(name, str):
+            raise ValueError(Errors.E963.format(decorator="factory"))
+        if "." in name:
+            raise ValueError(Errors.E853.format(name=name))
+        if not isinstance(default_config, dict):
+            err = Errors.E962.format(
+                style="default config", name=name, cfg_type=type(default_config)
+            )
+            raise ValueError(err)
+
+        def add_factory(factory_func: Callable) -> Callable:
+            internal_name = cls.get_factory_name(name)
+            if internal_name in registry.factories:
+                # We only check for the internal name here – it's okay if it's a
+                # subclass and the base class has a factory of the same name. We
+                # also only raise if the function is different to prevent raising
+                # if module is reloaded.
+                existing_func = registry.factories.get(internal_name)
+                if not util.is_same_func(factory_func, existing_func):
+                    err = Errors.E004.format(
+                        name=name, func=existing_func, new_func=factory_func
+                    )
+                    raise ValueError(err)
+
+            arg_names = util.get_arg_names(factory_func)
+            if "nlp" not in arg_names or "name" not in arg_names:
+                raise ValueError(Errors.E964.format(name=name))
+            # Officially register the factory so we can later call
+            # registry.resolve and refer to it in the config as
+            # @factories = "spacy.Language.xyz". We use the class name here so
+            # different classes can have different factories.
+            registry.factories.register(internal_name, func=factory_func)
+            factory_meta = FactoryMeta(
+                factory=name,
+                default_config=default_config,
+                assigns=validate_attrs(assigns),
+                requires=validate_attrs(requires),
+                scores=list(default_score_weights.keys()),
+                default_score_weights=default_score_weights,
+                retokenizes=retokenizes,
+            )
+            cls.set_factory_meta(name, factory_meta)
+            # We're overwriting the class attr with a frozen dict to handle
+            # backwards-compat (writing to Language.factories directly). This
+            # wouldn't work with an instance property and just produce a
+            # confusing error – here we can show a custom error
+            cls.factories = SimpleFrozenDict(
+                registry.factories.get_all(), error=Errors.E957
+            )
+            return factory_func
+
+        if func is not None:  # Support non-decorator use cases
+            return add_factory(func)
+        return add_factory
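+    # Illustrative use of the factory decorator (not part of the original
+    # source; `ThresholdComponent` is a hypothetical callable that takes and
+    # returns a Doc):
+    #
+    #     @Language.factory("threshold_component", default_config={"threshold": 0.5})
+    #     def create_threshold_component(nlp, name, threshold: float):
+    #         return ThresholdComponent(threshold)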
+
+    @classmethod
+    def component(
+        cls,
+        name: str,
+        *,
+        assigns: Iterable[str] = SimpleFrozenList(),
+        requires: Iterable[str] = SimpleFrozenList(),
+        retokenizes: bool = False,
+        func: Optional[PipeCallable] = None,
+    ) -> Callable[..., Any]:
+        """Register a new pipeline component. Can be used for stateless function
+        components that don't require a separate factory. Can be used as a
+        decorator on a function or classmethod, or called as a function with the
+        factory provided as the func keyword argument. To create a component and
+        add it to the pipeline, you can use nlp.add_pipe(name).
+
+        name (str): The name of the component factory.
+        assigns (Iterable[str]): Doc/Token attributes assigned by this component,
+            e.g. "token.ent_id". Used for pipeline analysis.
+        requires (Iterable[str]): Doc/Token attributes required by this component,
+            e.g. "token.ent_id". Used for pipeline analysis.
+        retokenizes (bool): Whether the component changes the tokenization.
+            Used for pipeline analysis.
+        func (Optional[Callable[[Doc], Doc]): Factory function if not used as a decorator.
+
+        DOCS: https://spacy.io/api/language#component
+        """
+        if name is not None:
+            if not isinstance(name, str):
+                raise ValueError(Errors.E963.format(decorator="component"))
+            if "." in name:
+                raise ValueError(Errors.E853.format(name=name))
+        component_name = name if name is not None else util.get_object_name(func)
+
+        def add_component(component_func: PipeCallable) -> Callable:
+            if isinstance(func, type):  # function is a class
+                raise ValueError(Errors.E965.format(name=component_name))
+
+            def factory_func(nlp, name: str) -> PipeCallable:
+                return component_func
+
+            internal_name = cls.get_factory_name(name)
+            if internal_name in registry.factories:
+                # We only check for the internal name here – it's okay if it's a
+                # subclass and the base class has a factory of the same name. We
+                # also only raise if the function is different to prevent raising
+                # if module is reloaded. It's hacky, but we need to check the
+                # existing function for a closure and whether that's identical
+                # to the component function (because factory_func created above
+                # will always be different, even for the same function)
+                existing_func = registry.factories.get(internal_name)
+                closure = existing_func.__closure__
+                wrapped = [c.cell_contents for c in closure][0] if closure else None
+                if util.is_same_func(wrapped, component_func):
+                    factory_func = existing_func  # noqa: F811
+
+            cls.factory(
+                component_name,
+                assigns=assigns,
+                requires=requires,
+                retokenizes=retokenizes,
+                func=factory_func,
+            )
+            return component_func
+
+        if func is not None:  # Support non-decorator use cases
+            return add_component(func)
+        return add_component
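+    # Illustrative use of the component decorator for a stateless function
+    # component (not part of the original source):
+    #
+    #     @Language.component("doc_length_logger")
+    #     def doc_length_logger(doc):
+    #         print(f"doc has {len(doc)} tokens")
+    #         return doc
+    #
+    #     nlp.add_pipe("doc_length_logger", first=True)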
+
+    def analyze_pipes(
+        self,
+        *,
+        keys: List[str] = ["assigns", "requires", "scores", "retokenizes"],
+        pretty: bool = False,
+    ) -> Optional[Dict[str, Any]]:
+        """Analyze the current pipeline components, print a summary of what
+        they assign or require and check that all requirements are met.
+
+        keys (List[str]): The meta values to display in the table. Corresponds
+            to values in FactoryMeta, defined by @Language.factory decorator.
+        pretty (bool): Pretty-print the results.
+        RETURNS (dict): The data.
+        """
+        analysis = analyze_pipes(self, keys=keys)
+        if pretty:
+            print_pipe_analysis(analysis, keys=keys)
+        return analysis
+
+    def get_pipe(self, name: str) -> PipeCallable:
+        """Get a pipeline component for a given component name.
+
+        name (str): Name of pipeline component to get.
+        RETURNS (callable): The pipeline component.
+
+        DOCS: https://spacy.io/api/language#get_pipe
+        """
+        for pipe_name, component in self._components:
+            if pipe_name == name:
+                return component
+        raise KeyError(Errors.E001.format(name=name, opts=self.component_names))
+
+    def create_pipe(
+        self,
+        factory_name: str,
+        name: Optional[str] = None,
+        *,
+        config: Dict[str, Any] = SimpleFrozenDict(),
+        raw_config: Optional[Config] = None,
+        validate: bool = True,
+    ) -> PipeCallable:
+        """Create a pipeline component. Mostly used internally. To create and
+        add a component to the pipeline, you can use nlp.add_pipe.
+
+        factory_name (str): Name of component factory.
+        name (Optional[str]): Optional name to assign to component instance.
+            Defaults to factory name if not set.
+        config (Dict[str, Any]): Config parameters to use for this component.
+            Will be merged with default config, if available.
+        raw_config (Optional[Config]): Internals: the non-interpolated config.
+        validate (bool): Whether to validate the component config against the
+            arguments and types expected by the factory.
+        RETURNS (Callable[[Doc], Doc]): The pipeline component.
+
+        DOCS: https://spacy.io/api/language#create_pipe
+        """
+        name = name if name is not None else factory_name
+        if not isinstance(config, dict):
+            err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
+            raise ValueError(err)
+        if not srsly.is_json_serializable(config):
+            raise ValueError(Errors.E961.format(config=config))
+        if not self.has_factory(factory_name):
+            err = Errors.E002.format(
+                name=factory_name,
+                opts=", ".join(self.factory_names),
+                method="create_pipe",
+                lang=util.get_object_name(self),
+                lang_code=self.lang,
+            )
+            raise ValueError(err)
+        pipe_meta = self.get_factory_meta(factory_name)
+        # This is unideal, but the alternative would mean you always need to
+        # specify the full config settings, which is not really viable.
+        if pipe_meta.default_config:
+            config = Config(pipe_meta.default_config).merge(config)
+        internal_name = self.get_factory_name(factory_name)
+        # If the language-specific factory doesn't exist, try again with the
+        # not-specific name
+        if internal_name not in registry.factories:
+            internal_name = factory_name
+        # The name allows components to know their pipe name and use it in the
+        # losses etc. (even if multiple instances of the same factory are used)
+        config = {"nlp": self, "name": name, **config, "@factories": internal_name}
+        # We need to create a top-level key because Thinc doesn't allow resolving
+        # top-level references to registered functions. Also gives nicer errors.
+        cfg = {factory_name: config}
+        # We're calling the internal _fill here to avoid constructing the
+        # registered functions twice
+        resolved = registry.resolve(cfg, validate=validate)
+        filled = registry.fill({"cfg": cfg[factory_name]}, validate=validate)["cfg"]
+        filled = Config(filled)
+        filled["factory"] = factory_name
+        filled.pop("@factories", None)
+        # Remove the extra values we added because we don't want to keep passing
+        # them around, copying them etc.
+        filled.pop("nlp", None)
+        filled.pop("name", None)
+        # Merge the final filled config with the raw config (including non-
+        # interpolated variables)
+        if raw_config:
+            filled = filled.merge(raw_config)
+        self._pipe_configs[name] = filled
+        return resolved[factory_name]
+
+    def create_pipe_from_source(
+        self, source_name: str, source: "Language", *, name: str
+    ) -> Tuple[PipeCallable, str]:
+        """Create a pipeline component by copying it from an existing model.
+
+        source_name (str): Name of the component in the source pipeline.
+        source (Language): The source nlp object to copy from.
+        name (str): Optional alternative name to use in current pipeline.
+        RETURNS (Tuple[Callable[[Doc], Doc], str]): The component and its factory name.
+        """
+        # Check source type
+        if not isinstance(source, Language):
+            raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
+        if self.vocab.vectors != source.vocab.vectors:
+            warnings.warn(Warnings.W113.format(name=source_name))
+        if source_name not in source.component_names:
+            raise KeyError(
+                Errors.E944.format(
+                    name=source_name,
+                    model=f"{source.meta['lang']}_{source.meta['name']}",
+                    opts=", ".join(source.component_names),
+                )
+            )
+        pipe = source.get_pipe(source_name)
+        # There is no actual solution here. Either the component has the right
+        # name for the source pipeline or the component has the right name for
+        # the current pipeline. This prioritizes the current pipeline.
+        if hasattr(pipe, "name"):
+            pipe.name = name
+        # Make sure the source config is interpolated so we don't end up with
+        # orphaned variables in our final config
+        source_config = source.config.interpolate()
+        pipe_config = util.copy_config(source_config["components"][source_name])
+        self._pipe_configs[name] = pipe_config
+        if self.vocab.strings != source.vocab.strings:
+            for s in source.vocab.strings:
+                self.vocab.strings.add(s)
+        return pipe, pipe_config["factory"]
+
+    def add_pipe(
+        self,
+        factory_name: str,
+        name: Optional[str] = None,
+        *,
+        before: Optional[Union[str, int]] = None,
+        after: Optional[Union[str, int]] = None,
+        first: Optional[bool] = None,
+        last: Optional[bool] = None,
+        source: Optional["Language"] = None,
+        config: Dict[str, Any] = SimpleFrozenDict(),
+        raw_config: Optional[Config] = None,
+        validate: bool = True,
+    ) -> PipeCallable:
+        """Add a component to the processing pipeline. Valid components are
+        callables that take a `Doc` object, modify it and return it. Only one
+        of before/after/first/last can be set. Default behaviour is "last".
+
+        factory_name (str): Name of the component factory.
+        name (str): Name of pipeline component. Overwrites existing
+            component.name attribute if available. If no name is set and
+            the component exposes no name attribute, component.__name__ is
+            used. An error is raised if a name already exists in the pipeline.
+        before (Union[str, int]): Name or index of the component to insert new
+            component directly before.
+        after (Union[str, int]): Name or index of the component to insert new
+            component directly after.
+        first (bool): If True, insert component first in the pipeline.
+        last (bool): If True, insert component last in the pipeline.
+        source (Language): Optional loaded nlp object to copy the pipeline
+            component from.
+        config (Dict[str, Any]): Config parameters to use for this component.
+            Will be merged with default config, if available.
+        raw_config (Optional[Config]): Internals: the non-interpolated config.
+        validate (bool): Whether to validate the component config against the
+            arguments and types expected by the factory.
+        RETURNS (Callable[[Doc], Doc]): The pipeline component.
+
+        DOCS: https://spacy.io/api/language#add_pipe
+        """
+        if not isinstance(factory_name, str):
+            bad_val = repr(factory_name)
+            err = Errors.E966.format(component=bad_val, name=name)
+            raise ValueError(err)
+        name = name if name is not None else factory_name
+        if name in self.component_names:
+            raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
+        # Overriding pipe name in the config is not supported and will be ignored.
+        if "name" in config:
+            warnings.warn(Warnings.W119.format(name_in_config=config.pop("name")))
+        if source is not None:
+            # We're loading the component from a model. After loading the
+            # component, we know its real factory name
+            pipe_component, factory_name = self.create_pipe_from_source(
+                factory_name, source, name=name
+            )
+        else:
+            pipe_component = self.create_pipe(
+                factory_name,
+                name=name,
+                config=config,
+                raw_config=raw_config,
+                validate=validate,
+            )
+        pipe_index = self._get_pipe_index(before, after, first, last)
+        self._pipe_meta[name] = self.get_factory_meta(factory_name)
+        self._components.insert(pipe_index, (name, pipe_component))
+        self._link_components()
+        return pipe_component
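+    # Illustrative usage of add_pipe (not part of the original source; assumes
+    # a blank English pipeline):
+    #
+    #     import spacy
+    #     nlp = spacy.blank("en")
+    #     nlp.add_pipe("sentencizer")
+    #     nlp.add_pipe("ner", last=True)
+    #     assert nlp.pipe_names == ["sentencizer", "ner"]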
+
+    def _get_pipe_index(
+        self,
+        before: Optional[Union[str, int]] = None,
+        after: Optional[Union[str, int]] = None,
+        first: Optional[bool] = None,
+        last: Optional[bool] = None,
+    ) -> int:
+        """Determine where to insert a pipeline component based on the before/
+        after/first/last values.
+
+        before (str): Name or index of the component to insert directly before.
+        after (str): Name or index of component to insert directly after.
+        first (bool): If True, insert component first in the pipeline.
+        last (bool): If True, insert component last in the pipeline.
+        RETURNS (int): The index of the new pipeline component.
+        """
+        all_args = {"before": before, "after": after, "first": first, "last": last}
+        if sum(arg is not None for arg in [before, after, first, last]) >= 2:
+            raise ValueError(
+                Errors.E006.format(args=all_args, opts=self.component_names)
+            )
+        if last or not any(value is not None for value in [first, before, after]):
+            return len(self._components)
+        elif first:
+            return 0
+        elif isinstance(before, str):
+            if before not in self.component_names:
+                raise ValueError(
+                    Errors.E001.format(name=before, opts=self.component_names)
+                )
+            return self.component_names.index(before)
+        elif isinstance(after, str):
+            if after not in self.component_names:
+                raise ValueError(
+                    Errors.E001.format(name=after, opts=self.component_names)
+                )
+            return self.component_names.index(after) + 1
+        # We're only accepting indices referring to components that exist
+        # (can't just do isinstance here because bools are instance of int, too)
+        elif type(before) == int:
+            if before >= len(self._components) or before < 0:
+                err = Errors.E959.format(
+                    dir="before", idx=before, opts=self.component_names
+                )
+                raise ValueError(err)
+            return before
+        elif type(after) == int:
+            if after >= len(self._components) or after < 0:
+                err = Errors.E959.format(
+                    dir="after", idx=after, opts=self.component_names
+                )
+                raise ValueError(err)
+            return after + 1
+        raise ValueError(Errors.E006.format(args=all_args, opts=self.component_names))
+
+    def has_pipe(self, name: str) -> bool:
+        """Check if a component name is present in the pipeline. Equivalent to
+        `name in nlp.pipe_names`.
+
+        name (str): Name of the component.
+        RETURNS (bool): Whether a component of the name exists in the pipeline.
+
+        DOCS: https://spacy.io/api/language#has_pipe
+        """
+        return name in self.pipe_names
+
+    def replace_pipe(
+        self,
+        name: str,
+        factory_name: str,
+        *,
+        config: Dict[str, Any] = SimpleFrozenDict(),
+        validate: bool = True,
+    ) -> PipeCallable:
+        """Replace a component in the pipeline.
+
+        name (str): Name of the component to replace.
+        factory_name (str): Factory name of replacement component.
+        config (Optional[Dict[str, Any]]): Config parameters to use for this
+            component. Will be merged with default config, if available.
+        validate (bool): Whether to validate the component config against the
+            arguments and types expected by the factory.
+        RETURNS (Callable[[Doc], Doc]): The new pipeline component.
+
+        DOCS: https://spacy.io/api/language#replace_pipe
+        """
+        if name not in self.component_names:
+            raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
+        if hasattr(factory_name, "__call__"):
+            err = Errors.E968.format(component=repr(factory_name), name=name)
+            raise ValueError(err)
+        # We need to delegate to Language.add_pipe here instead of just writing
+        # to Language.pipeline to make sure the configs are handled correctly
+        pipe_index = self.component_names.index(name)
+        self.remove_pipe(name)
+        if not len(self._components) or pipe_index == len(self._components):
+            # we have no components to insert before/after, or we're replacing the last component
+            return self.add_pipe(
+                factory_name, name=name, config=config, validate=validate
+            )
+        else:
+            return self.add_pipe(
+                factory_name,
+                name=name,
+                before=pipe_index,
+                config=config,
+                validate=validate,
+            )
+
+    def rename_pipe(self, old_name: str, new_name: str) -> None:
+        """Rename a pipeline component.
+
+        old_name (str): Name of the component to rename.
+        new_name (str): New name of the component.
+
+        DOCS: https://spacy.io/api/language#rename_pipe
+        """
+        if old_name not in self.component_names:
+            raise ValueError(
+                Errors.E001.format(name=old_name, opts=self.component_names)
+            )
+        if new_name in self.component_names:
+            raise ValueError(
+                Errors.E007.format(name=new_name, opts=self.component_names)
+            )
+        i = self.component_names.index(old_name)
+        self._components[i] = (new_name, self._components[i][1])
+        self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
+        self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
+        # Make sure [initialize] config is adjusted
+        if old_name in self._config["initialize"]["components"]:
+            init_cfg = self._config["initialize"]["components"].pop(old_name)
+            self._config["initialize"]["components"][new_name] = init_cfg
+        self._link_components()
+
+    def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
+        """Remove a component from the pipeline.
+
+        name (str): Name of the component to remove.
+        RETURNS (Tuple[str, Callable[[Doc], Doc]]): A `(name, component)` tuple of the removed component.
+
+        DOCS: https://spacy.io/api/language#remove_pipe
+        """
+        if name not in self.component_names:
+            raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
+        removed = self._components.pop(self.component_names.index(name))
+        # We're only removing the component itself from the metas/configs here
+        # because factory may be used for something else
+        self._pipe_meta.pop(name)
+        self._pipe_configs.pop(name)
+        self.meta.get("_sourced_vectors_hashes", {}).pop(name, None)
+        # Make sure name is removed from the [initialize] config
+        if name in self._config["initialize"]["components"]:
+            self._config["initialize"]["components"].pop(name)
+        # Make sure the name is also removed from the set of disabled components
+        if name in self.disabled:
+            self._disabled.remove(name)
+        self._link_components()
+        return removed
+
+    def disable_pipe(self, name: str) -> None:
+        """Disable a pipeline component. The component will still exist on
+        the nlp object, but it won't be run as part of the pipeline. Does
+        nothing if the component is already disabled.
+
+        name (str): The name of the component to disable.
+        """
+        if name not in self.component_names:
+            raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
+        self._disabled.add(name)
+
+    def enable_pipe(self, name: str) -> None:
+        """Enable a previously disabled pipeline component so it's run as part
+        of the pipeline. Does nothing if the component is already enabled.
+
+        name (str): The name of the component to enable.
+        """
+        if name not in self.component_names:
+            raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
+        if name in self.disabled:
+            self._disabled.remove(name)
+
+    def __call__(
+        self,
+        text: Union[str, Doc],
+        *,
+        disable: Iterable[str] = SimpleFrozenList(),
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+    ) -> Doc:
+        """Apply the pipeline to some text. The text can span multiple sentences,
+        and can contain arbitrary whitespace. Alignment into the original string
+        is preserved.
+
+        text (Union[str, Doc]): If `str`, the text to be processed. If `Doc`,
+            the doc will be passed directly to the pipeline, skipping
+            `Language.make_doc`.
+        disable (List[str]): Names of the pipeline components to disable.
+        component_cfg (Dict[str, dict]): An optional dictionary with extra
+            keyword arguments for specific components.
+        RETURNS (Doc): A container for accessing the annotations.
+
+        DOCS: https://spacy.io/api/language#call
+        """
+        doc = self._ensure_doc(text)
+        if component_cfg is None:
+            component_cfg = {}
+        for name, proc in self.pipeline:
+            if name in disable:
+                continue
+            if not hasattr(proc, "__call__"):
+                raise ValueError(Errors.E003.format(component=type(proc), name=name))
+            error_handler = self.default_error_handler
+            if hasattr(proc, "get_error_handler"):
+                error_handler = proc.get_error_handler()
+            try:
+                doc = proc(doc, **component_cfg.get(name, {}))  # type: ignore[call-arg]
+            except KeyError as e:
+                # This typically happens if a component is not initialized
+                raise ValueError(Errors.E109.format(name=name)) from e
+            except Exception as e:
+                error_handler(name, proc, [doc], e)
+            if not isinstance(doc, Doc):
+                raise ValueError(Errors.E005.format(name=name, returned_type=type(doc)))
+        return doc
+
+    def disable_pipes(self, *names) -> "DisabledPipes":
+        """Disable one or more pipeline components. If used as a context
+        manager, the pipeline will be restored to the initial state at the end
+        of the block. Otherwise, a DisabledPipes object is returned, that has
+        a `.restore()` method you can use to undo your changes.
+
+        This method has been deprecated since 3.0
+        """
+        warnings.warn(Warnings.W096, DeprecationWarning)
+        if len(names) == 1 and isinstance(names[0], (list, tuple)):
+            names = names[0]  # type: ignore[assignment]    # support list of names instead of spread
+        return self.select_pipes(disable=names)
+
+    def select_pipes(
+        self,
+        *,
+        disable: Optional[Union[str, Iterable[str]]] = None,
+        enable: Optional[Union[str, Iterable[str]]] = None,
+    ) -> "DisabledPipes":
+        """Disable one or more pipeline components. If used as a context
+        manager, the pipeline will be restored to the initial state at the end
+        of the block. Otherwise, a DisabledPipes object is returned, that has
+        a `.restore()` method you can use to undo your changes.
+
+        disable (str or iterable): The name(s) of the pipes to disable
+        enable (str or iterable): The name(s) of the pipes to enable - all others will be disabled
+
+        DOCS: https://spacy.io/api/language#select_pipes
+        """
+        if enable is None and disable is None:
+            raise ValueError(Errors.E991)
+        if isinstance(disable, str):
+            disable = [disable]
+        if enable is not None:
+            if isinstance(enable, str):
+                enable = [enable]
+            to_disable = [pipe for pipe in self.pipe_names if pipe not in enable]
+            # raise an error if the enable and disable keywords are not consistent
+            if disable is not None and disable != to_disable:
+                raise ValueError(
+                    Errors.E992.format(
+                        enable=enable, disable=disable, names=self.pipe_names
+                    )
+                )
+            disable = to_disable
+        assert disable is not None
+        # DisabledPipes will restore the pipes in 'disable' when it's done, so we need to exclude
+        # those pipes that were already disabled.
+        disable = [d for d in disable if d not in self._disabled]
+        return DisabledPipes(self, disable)
+
+    def make_doc(self, text: str) -> Doc:
+        """Turn a text into a Doc object.
+
+        text (str): The text to process.
+        RETURNS (Doc): The processed doc.
+        """
+        if len(text) > self.max_length:
+            raise ValueError(
+                Errors.E088.format(length=len(text), max_length=self.max_length)
+            )
+        return self.tokenizer(text)
+
+    def _ensure_doc(self, doc_like: Union[str, Doc, bytes]) -> Doc:
+        """Create a Doc if need be, or raise an error if the input is not
+        a Doc, string, or a byte array (generated by Doc.to_bytes())."""
+        if isinstance(doc_like, Doc):
+            return doc_like
+        if isinstance(doc_like, str):
+            return self.make_doc(doc_like)
+        if isinstance(doc_like, bytes):
+            return Doc(self.vocab).from_bytes(doc_like)
+        raise ValueError(Errors.E1041.format(type=type(doc_like)))
+
+    def _ensure_doc_with_context(
+        self, doc_like: Union[str, Doc, bytes], context: _AnyContext
+    ) -> Doc:
+        """Call _ensure_doc to generate a Doc and set its context object."""
+        doc = self._ensure_doc(doc_like)
+        doc._context = context
+        return doc
+
+    def update(
+        self,
+        examples: Iterable[Example],
+        _: Optional[Any] = None,
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+        exclude: Iterable[str] = SimpleFrozenList(),
+        annotates: Iterable[str] = SimpleFrozenList(),
+    ):
+        """Update the models in the pipeline.
+
+        examples (Iterable[Example]): A batch of examples
+        _: Should not be set - serves to catch backwards-incompatible scripts.
+        drop (float): The dropout rate.
+        sgd (Optimizer): An optimizer.
+        losses (Dict[str, float]): Dictionary to update with the loss, keyed by
+            component.
+        component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
+            components, keyed by component name.
+        exclude (Iterable[str]): Names of components that shouldn't be updated.
+        annotates (Iterable[str]): Names of components that should set
+            annotations on the predicted examples after updating.
+        RETURNS (Dict[str, float]): The updated losses dictionary
+
+        DOCS: https://spacy.io/api/language#update
+        """
+        if _ is not None:
+            raise ValueError(Errors.E989)
+        if losses is None:
+            losses = {}
+        if isinstance(examples, list) and len(examples) == 0:
+            return losses
+        validate_examples(examples, "Language.update")
+        examples = _copy_examples(examples)
+        if sgd is None:
+            if self._optimizer is None:
+                self._optimizer = self.create_optimizer()
+            sgd = self._optimizer
+        if component_cfg is None:
+            component_cfg = {}
+        pipe_kwargs = {}
+        for i, (name, proc) in enumerate(self.pipeline):
+            component_cfg.setdefault(name, {})
+            pipe_kwargs[name] = deepcopy(component_cfg[name])
+            component_cfg[name].setdefault("drop", drop)
+            pipe_kwargs[name].setdefault("batch_size", self.batch_size)
+        for name, proc in self.pipeline:
+            # ignore statements are used here because mypy ignores hasattr
+            if name not in exclude and hasattr(proc, "update"):
+                proc.update(examples, sgd=None, losses=losses, **component_cfg[name])  # type: ignore
+            if sgd not in (None, False):
+                if (
+                    name not in exclude
+                    and isinstance(proc, ty.TrainableComponent)
+                    and proc.is_trainable
+                    and proc.model not in (True, False, None)
+                ):
+                    proc.finish_update(sgd)
+            if name in annotates:
+                for doc, eg in zip(
+                    _pipe(
+                        (eg.predicted for eg in examples),
+                        proc=proc,
+                        name=name,
+                        default_error_handler=self.default_error_handler,
+                        kwargs=pipe_kwargs[name],
+                    ),
+                    examples,
+                ):
+                    eg.predicted = doc
+        return _replace_numpy_floats(losses)
+
+    def rehearse(
+        self,
+        examples: Iterable[Example],
+        *,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+        exclude: Iterable[str] = SimpleFrozenList(),
+    ) -> Dict[str, float]:
+        """Make a "rehearsal" update to the models in the pipeline, to prevent
+        forgetting. Rehearsal updates run an initial copy of the model over some
+        data, and update the model so its current predictions are more like the
+        initial ones. This is useful for keeping a pretrained model on-track,
+        even if you're updating it with a smaller set of examples.
+
+        examples (Iterable[Example]): A batch of `Example` objects.
+        sgd (Optional[Optimizer]): An optimizer.
+        component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
+            components, keyed by component name.
+        exclude (Iterable[str]): Names of components that shouldn't be updated.
+        RETURNS (dict): Results from the update.
+
+        EXAMPLE:
+            >>> raw_text_batches = minibatch(raw_texts)
+            >>> for labelled_batch in minibatch(examples):
+            >>>     nlp.update(labelled_batch)
+            >>>     raw_batch = [Example.from_dict(nlp.make_doc(text), {}) for text in next(raw_text_batches)]
+            >>>     nlp.rehearse(raw_batch)
+
+        DOCS: https://spacy.io/api/language#rehearse
+        """
+        if losses is None:
+            losses = {}
+        if isinstance(examples, list) and len(examples) == 0:
+            return losses
+        validate_examples(examples, "Language.rehearse")
+        if sgd is None:
+            if self._optimizer is None:
+                self._optimizer = self.create_optimizer()
+            sgd = self._optimizer
+        pipes = list(self.pipeline)
+        random.shuffle(pipes)
+        if component_cfg is None:
+            component_cfg = {}
+        grads = {}
+
+        def get_grads(key, W, dW):
+            grads[key] = (W, dW)
+            return W, dW
+
+        get_grads.learn_rate = sgd.learn_rate  # type: ignore[attr-defined, union-attr]
+        get_grads.b1 = sgd.b1  # type: ignore[attr-defined, union-attr]
+        get_grads.b2 = sgd.b2  # type: ignore[attr-defined, union-attr]
+        for name, proc in pipes:
+            if name in exclude or not hasattr(proc, "rehearse"):
+                continue
+            grads = {}
+            proc.rehearse(  # type: ignore[attr-defined]
+                examples, sgd=get_grads, losses=losses, **component_cfg.get(name, {})
+            )
+        for key, (W, dW) in grads.items():
+            sgd(key, W, dW)  # type: ignore[call-arg, misc]
+        return losses
+
+    def begin_training(
+        self,
+        get_examples: Optional[Callable[[], Iterable[Example]]] = None,
+        *,
+        sgd: Optional[Optimizer] = None,
+    ) -> Optimizer:
+        warnings.warn(Warnings.W089, DeprecationWarning)
+        return self.initialize(get_examples, sgd=sgd)
+
+    def initialize(
+        self,
+        get_examples: Optional[Callable[[], Iterable[Example]]] = None,
+        *,
+        sgd: Optional[Optimizer] = None,
+    ) -> Optimizer:
+        """Initialize the pipe for training, using data examples if available.
+
+        get_examples (Callable[[], Iterable[Example]]): Optional function that
+            returns gold-standard Example objects.
+        sgd (Optional[Optimizer]): An optimizer to use for updates. If not
+            provided, will be created using the .create_optimizer() method.
+        RETURNS (thinc.api.Optimizer): The optimizer.
+
+        DOCS: https://spacy.io/api/language#initialize
+        """
+        if get_examples is None:
+            util.logger.debug(
+                "No 'get_examples' callback provided to 'Language.initialize', creating dummy examples"
+            )
+            doc = Doc(self.vocab, words=["x", "y", "z"])
+
+            def get_examples():
+                return [Example.from_dict(doc, {})]
+
+        if not hasattr(get_examples, "__call__"):
+            err = Errors.E930.format(
+                method="Language.initialize", obj=type(get_examples)
+            )
+            raise TypeError(err)
+        # Make sure the config is interpolated so we can resolve subsections
+        config = self.config.interpolate()
+        # These are the settings provided in the [initialize] block in the config
+        I = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
+        before_init = I["before_init"]
+        if before_init is not None:
+            before_init(self)
+        try:
+            init_vocab(
+                self, data=I["vocab_data"], lookups=I["lookups"], vectors=I["vectors"]
+            )
+        except IOError:
+            raise IOError(Errors.E884.format(vectors=I["vectors"]))
+        if self.vocab.vectors.shape[1] >= 1:
+            ops = get_current_ops()
+            self.vocab.vectors.to_ops(ops)
+        if hasattr(self.tokenizer, "initialize"):
+            tok_settings = validate_init_settings(
+                self.tokenizer.initialize,  # type: ignore[union-attr]
+                I["tokenizer"],
+                section="tokenizer",
+                name="tokenizer",
+            )
+            self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)  # type: ignore[union-attr]
+        for name, proc in self.pipeline:
+            if isinstance(proc, ty.InitializableComponent):
+                p_settings = I["components"].get(name, {})
+                p_settings = validate_init_settings(
+                    proc.initialize, p_settings, section="components", name=name
+                )
+                proc.initialize(get_examples, nlp=self, **p_settings)
+        pretrain_cfg = config.get("pretraining")
+        if pretrain_cfg:
+            P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
+            init_tok2vec(self, P, I)
+        self._link_components()
+        self._optimizer = sgd
+        if sgd is not None:
+            self._optimizer = sgd
+        elif self._optimizer is None:
+            self._optimizer = self.create_optimizer()
+        after_init = I["after_init"]
+        if after_init is not None:
+            after_init(self)
+        return self._optimizer
+
+    def resume_training(self, *, sgd: Optional[Optimizer] = None) -> Optimizer:
+        """Continue training a pretrained model.
+
+        Create and return an optimizer, and initialize "rehearsal" for any pipeline
+        component that has a .rehearse() method. Rehearsal is used to prevent
+        models from "forgetting" their initialized "knowledge". To perform
+        rehearsal, collect samples of text you want the models to retain performance
+        on, and call nlp.rehearse() with a batch of Example objects.
+
+        RETURNS (Optimizer): The optimizer.
+
+        DOCS: https://spacy.io/api/language#resume_training
+        """
+        ops = get_current_ops()
+        if self.vocab.vectors.shape[1] >= 1:
+            self.vocab.vectors.to_ops(ops)
+        for name, proc in self.pipeline:
+            if hasattr(proc, "_rehearsal_model"):
+                proc._rehearsal_model = deepcopy(proc.model)  # type: ignore[attr-defined]
+        if sgd is not None:
+            self._optimizer = sgd
+        elif self._optimizer is None:
+            self._optimizer = self.create_optimizer()
+        return self._optimizer
+
+    def set_error_handler(
+        self,
+        error_handler: Callable[[str, PipeCallable, List[Doc], Exception], NoReturn],
+    ):
+        """Set an error handler object for all the components in the pipeline
+        that implement a set_error_handler function.
+
+        error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], NoReturn]):
+            Function that deals with a failing batch of documents. This callable
+            function should take in the component's name, the component itself,
+            the offending batch of documents, and the exception that was thrown.
+        DOCS: https://spacy.io/api/language#set_error_handler
+        """
+        self.default_error_handler = error_handler
+        for name, pipe in self.pipeline:
+            if hasattr(pipe, "set_error_handler"):
+                pipe.set_error_handler(error_handler)
+
+    def evaluate(
+        self,
+        examples: Iterable[Example],
+        *,
+        batch_size: Optional[int] = None,
+        scorer: Optional[Scorer] = None,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+        scorer_cfg: Optional[Dict[str, Any]] = None,
+        per_component: bool = False,
+    ) -> Dict[str, Any]:
+        """Evaluate a model's pipeline components.
+
+        examples (Iterable[Example]): `Example` objects.
+        batch_size (Optional[int]): Batch size to use.
+        scorer (Optional[Scorer]): Scorer to use. If not passed in, a new one
+            will be created.
+        component_cfg (dict): An optional dictionary with extra keyword
+            arguments for specific components.
+        scorer_cfg (dict): An optional dictionary with extra keyword arguments
+            for the scorer.
+        per_component (bool): Whether to return the scores keyed by component
+            name. Defaults to False.
+
+        RETURNS (Scorer): The scorer containing the evaluation results.
+
+        DOCS: https://spacy.io/api/language#evaluate
+        """
+        examples = list(examples)
+        validate_examples(examples, "Language.evaluate")
+        examples = _copy_examples(examples)
+        if batch_size is None:
+            batch_size = self.batch_size
+        if component_cfg is None:
+            component_cfg = {}
+        if scorer_cfg is None:
+            scorer_cfg = {}
+        if scorer is None:
+            kwargs = dict(scorer_cfg)
+            kwargs.setdefault("nlp", self)
+            scorer = Scorer(**kwargs)
+        # reset annotation in predicted docs and time tokenization
+        start_time = timer()
+        # this is purely for timing
+        for eg in examples:
+            self.make_doc(eg.reference.text)
+        # apply all pipeline components
+        docs = self.pipe(
+            (eg.predicted for eg in examples),
+            batch_size=batch_size,
+            component_cfg=component_cfg,
+        )
+        for eg, doc in zip(examples, docs):
+            eg.predicted = doc
+        end_time = timer()
+        results = scorer.score(examples, per_component=per_component)
+        n_words = sum(len(eg.predicted) for eg in examples)
+        results["speed"] = n_words / (end_time - start_time)
+        return _replace_numpy_floats(results)
+
+    def create_optimizer(self):
+        """Create an optimizer, usually using the [training.optimizer] config."""
+        subconfig = {"optimizer": self.config["training"]["optimizer"]}
+        return registry.resolve(subconfig)["optimizer"]
+
+    @contextmanager
+    def use_params(self, params: Optional[dict]):
+        """Replace weights of models in the pipeline with those provided in the
+        params dictionary. Can be used as a contextmanager, in which case,
+        models go back to their original weights after the block.
+
+        params (dict): A dictionary of parameters keyed by model ID.
+
+        EXAMPLE:
+            >>> with nlp.use_params(optimizer.averages):
+            >>>     nlp.to_disk("/tmp/checkpoint")
+
+        DOCS: https://spacy.io/api/language#use_params
+        """
+        if not params:
+            yield
+        else:
+            contexts = [
+                pipe.use_params(params)  # type: ignore[attr-defined]
+                for name, pipe in self.pipeline
+                if hasattr(pipe, "use_params") and hasattr(pipe, "model")
+            ]
+            # TODO: Having trouble with contextlib
+            # Workaround: these aren't actually context managers atm.
+            for context in contexts:
+                try:
+                    next(context)
+                except StopIteration:
+                    pass
+            yield
+            for context in contexts:
+                try:
+                    next(context)
+                except StopIteration:
+                    pass
+
+    @overload
+    def pipe(
+        self,
+        texts: Iterable[Union[str, Doc]],
+        *,
+        as_tuples: Literal[False] = ...,
+        batch_size: Optional[int] = ...,
+        disable: Iterable[str] = ...,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = ...,
+        n_process: int = ...,
+    ) -> Iterator[Doc]:
+        ...
+
+    @overload
+    def pipe(  # noqa: F811
+        self,
+        texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
+        *,
+        as_tuples: Literal[True] = ...,
+        batch_size: Optional[int] = ...,
+        disable: Iterable[str] = ...,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = ...,
+        n_process: int = ...,
+    ) -> Iterator[Tuple[Doc, _AnyContext]]:
+        ...
+
+    def pipe(  # noqa: F811
+        self,
+        texts: Union[
+            Iterable[Union[str, Doc]], Iterable[Tuple[Union[str, Doc], _AnyContext]]
+        ],
+        *,
+        as_tuples: bool = False,
+        batch_size: Optional[int] = None,
+        disable: Iterable[str] = SimpleFrozenList(),
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+        n_process: int = 1,
+    ) -> Union[Iterator[Doc], Iterator[Tuple[Doc, _AnyContext]]]:
+        """Process texts as a stream, and yield `Doc` objects in order.
+
+        texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to
+            process.
+        as_tuples (bool): If set to True, inputs should be a sequence of
+            (text, context) tuples. Output will then be a sequence of
+            (doc, context) tuples. Defaults to False.
+        batch_size (Optional[int]): The number of texts to buffer.
+        disable (List[str]): Names of the pipeline components to disable.
+        component_cfg (Dict[str, Dict]): An optional dictionary with extra keyword
+            arguments for specific components.
+        n_process (int): Number of processors to process texts. If -1, set `multiprocessing.cpu_count()`.
+        YIELDS (Doc): Documents in the order of the original text.
+
+        DOCS: https://spacy.io/api/language#pipe
+        """
+        if as_tuples:
+            texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
+            docs_with_contexts = (
+                self._ensure_doc_with_context(text, context) for text, context in texts
+            )
+            docs = self.pipe(
+                docs_with_contexts,
+                batch_size=batch_size,
+                disable=disable,
+                n_process=n_process,
+                component_cfg=component_cfg,
+            )
+            for doc in docs:
+                context = doc._context
+                doc._context = None
+                yield (doc, context)
+            return
+
+        texts = cast(Iterable[Union[str, Doc]], texts)
+
+        # Set argument defaults
+        if n_process == -1:
+            n_process = mp.cpu_count()
+        if component_cfg is None:
+            component_cfg = {}
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        pipes = (
+            []
+        )  # contains functools.partial objects to easily create multiprocess worker.
+        for name, proc in self.pipeline:
+            if name in disable:
+                continue
+            kwargs = component_cfg.get(name, {})
+            # Allow component_cfg to overwrite the top-level kwargs.
+            kwargs.setdefault("batch_size", batch_size)
+            f = functools.partial(
+                _pipe,
+                proc=proc,
+                name=name,
+                kwargs=kwargs,
+                default_error_handler=self.default_error_handler,
+            )
+            pipes.append(f)
+
+        if n_process != 1:
+            if self._has_gpu_model(disable):
+                warnings.warn(Warnings.W114)
+
+            docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size)
+        else:
+            # if n_process == 1, no processes are forked.
+            docs = (self._ensure_doc(text) for text in texts)
+            for pipe in pipes:
+                docs = pipe(docs)
+        for doc in docs:
+            yield doc
+
+    def _has_gpu_model(self, disable: Iterable[str]):
+        for name, proc in self.pipeline:
+            is_trainable = hasattr(proc, "is_trainable") and proc.is_trainable  # type: ignore
+            if name in disable or not is_trainable:
+                continue
+
+            if hasattr(proc, "model") and hasattr(proc.model, "ops") and isinstance(proc.model.ops, CupyOps):  # type: ignore
+                return True
+
+        return False
+
+    def _multiprocessing_pipe(
+        self,
+        texts: Iterable[Union[str, Doc]],
+        pipes: Iterable[Callable[..., Iterator[Doc]]],
+        n_process: int,
+        batch_size: int,
+    ) -> Iterator[Doc]:
+        def prepare_input(
+            texts: Iterable[Union[str, Doc]]
+        ) -> Iterable[Tuple[Union[str, bytes], _AnyContext]]:
+            # Serialize Doc inputs to bytes to avoid incurring pickling
+            # overhead when they are passed to child processes. Also yield
+            # any context objects they might have separately (as they are not serialized).
+            for doc_like in texts:
+                if isinstance(doc_like, Doc):
+                    yield (doc_like.to_bytes(), cast(_AnyContext, doc_like._context))
+                else:
+                    yield (doc_like, cast(_AnyContext, None))
+
+        serialized_texts_with_ctx = prepare_input(texts)  # type: ignore
+        # raw_texts is used later to stop iteration.
+        texts, raw_texts = itertools.tee(serialized_texts_with_ctx)  # type: ignore
+        # for sending texts to worker
+        texts_q: List[mp.Queue] = [mp.Queue() for _ in range(n_process)]
+        # for receiving byte-encoded docs from worker
+        bytedocs_recv_ch, bytedocs_send_ch = zip(
+            *[mp.Pipe(False) for _ in range(n_process)]
+        )
+
+        batch_texts = util.minibatch(texts, batch_size)
+        # Sender sends texts to the workers.
+        # This is necessary to properly handle infinite length of texts.
+        # (In this case, all data cannot be sent to the workers at once)
+        sender = _Sender(batch_texts, texts_q, chunk_size=n_process)
+        # send twice to make process busy
+        sender.send()
+        sender.send()
+
+        procs = [
+            mp.Process(
+                target=_apply_pipes,
+                args=(
+                    self._ensure_doc_with_context,
+                    pipes,
+                    rch,
+                    sch,
+                    Underscore.get_state(),
+                ),
+            )
+            for rch, sch in zip(texts_q, bytedocs_send_ch)
+        ]
+        for proc in procs:
+            proc.start()
+
+        # Close writing-end of channels. This is needed to avoid that reading
+        # from the channel blocks indefinitely when the worker closes the
+        # channel.
+        for tx in bytedocs_send_ch:
+            tx.close()
+
+        # Cycle channels so as not to break the order of docs.
+        # The received object is a batch of byte-encoded docs, so flatten them with chain.from_iterable.
+        byte_tuples = chain.from_iterable(
+            recv.recv() for recv in cycle(bytedocs_recv_ch)
+        )
+        try:
+            for i, (_, (byte_doc, context, byte_error)) in enumerate(
+                zip(raw_texts, byte_tuples), 1
+            ):
+                if byte_doc is not None:
+                    doc = Doc(self.vocab).from_bytes(byte_doc)
+                    doc._context = context
+                    yield doc
+                elif byte_error is not None:
+                    error = srsly.msgpack_loads(byte_error)
+                    self.default_error_handler(
+                        None, None, None, ValueError(Errors.E871.format(error=error))
+                    )
+                if i % batch_size == 0:
+                    # tell `sender` that one batch was consumed.
+                    sender.step()
+        finally:
+            # If we are stopping in an orderly fashion, the workers' queues
+            # are empty. Put the sentinel in their queues to signal that work
+            # is done, so that they can exit gracefully.
+            for q in texts_q:
+                q.put(_WORK_DONE_SENTINEL)
+                q.close()
+
+            # Otherwise, we are stopping because the error handler raised an
+            # exception. The sentinel will be last to go out of the queue.
+            # To avoid doing unnecessary work or hanging on platforms that
+            # block on sending (Windows), we'll close our end of the channel.
+            # This signals to the worker that it can exit the next time it
+            # attempts to send data down the channel.
+            for r in bytedocs_recv_ch:
+                r.close()
+
+            for proc in procs:
+                proc.join()
+
+            if not all(proc.exitcode == 0 for proc in procs):
+                warnings.warn(Warnings.W127)
+
+    def _link_components(self) -> None:
+        """Register 'listeners' within pipeline components, to allow them to
+        effectively share weights.
+        """
+        # I had thought, "Why do we do this inside the Language object? Shouldn't
+        # it be the tok2vec/transformer/etc's job?"
+        # The problem is we need to do it during deserialization...And the
+        # components don't receive the pipeline then. So this does have to be
+        # here :(
+        # First, fix up all the internal component names in case they have
+        # gotten out of sync due to sourcing components from different
+        # pipelines, since find_listeners uses proc2.name for the listener
+        # map.
+        for name, proc in self.pipeline:
+            if hasattr(proc, "name"):
+                proc.name = name
+        for i, (name1, proc1) in enumerate(self.pipeline):
+            if isinstance(proc1, ty.ListenedToComponent):
+                proc1.listener_map = {}
+                for name2, proc2 in self.pipeline[i + 1 :]:
+                    proc1.find_listeners(proc2)
+
+    @classmethod
+    def from_config(
+        cls,
+        config: Union[Dict[str, Any], Config] = {},
+        *,
+        vocab: Union[Vocab, bool] = True,
+        disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
+        enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
+        exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
+        meta: Dict[str, Any] = SimpleFrozenDict(),
+        auto_fill: bool = True,
+        validate: bool = True,
+    ) -> "Language":
+        """Create the nlp object from a loaded config. Will set up the tokenizer
+        and language data, add pipeline components etc. If no config is provided,
+        the default config of the given language is used.
+
+        config (Dict[str, Any] / Config): The loaded config.
+        vocab (Vocab): A Vocab object. If True, a vocab is created.
+        disable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to disable.
+            Disabled pipes will be loaded but they won't be run unless you
+            explicitly enable them by calling nlp.enable_pipe.
+        enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable. All other
+            pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
+        exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude.
+            Excluded components won't be loaded.
+        meta (Dict[str, Any]): Meta overrides for nlp.meta.
+        auto_fill (bool): Automatically fill in missing values in config based
+            on defaults and function argument annotations.
+        validate (bool): Validate the component config and arguments against
+            the types expected by the factory.
+        RETURNS (Language): The initialized Language class.
+
+        DOCS: https://spacy.io/api/language#from_config
+        """
+        if auto_fill:
+            config = Config(
+                cls.default_config, section_order=CONFIG_SECTION_ORDER
+            ).merge(config)
+        if "nlp" not in config:
+            raise ValueError(Errors.E985.format(config=config))
+        # fill in [nlp.vectors] if not present (as a narrower alternative to
+        # auto-filling [nlp] from the default config)
+        if "vectors" not in config["nlp"]:
+            config["nlp"]["vectors"] = {"@vectors": "spacy.Vectors.v1"}
+        config_lang = config["nlp"].get("lang")
+        if config_lang is not None and config_lang != cls.lang:
+            raise ValueError(
+                Errors.E958.format(
+                    bad_lang_code=config["nlp"]["lang"],
+                    lang_code=cls.lang,
+                    lang=util.get_object_name(cls),
+                )
+            )
+        config["nlp"]["lang"] = cls.lang
+        # This isn't very elegant, but we remove the [components] block here to prevent
+        # it from getting resolved (causes problems because we expect to pass in
+        # the nlp and name args for each component). If we're auto-filling, we're
+        # using the nlp.config with all defaults.
+        config = util.copy_config(config)
+        orig_pipeline = config.pop("components", {})
+        orig_pretraining = config.pop("pretraining", None)
+        config["components"] = {}
+        if auto_fill:
+            filled = registry.fill(config, validate=validate, schema=ConfigSchema)
+        else:
+            filled = config
+        filled["components"] = orig_pipeline
+        config["components"] = orig_pipeline
+        if orig_pretraining is not None:
+            filled["pretraining"] = orig_pretraining
+            config["pretraining"] = orig_pretraining
+        resolved_nlp = registry.resolve(
+            filled["nlp"], validate=validate, schema=ConfigSchemaNlp
+        )
+        create_tokenizer = resolved_nlp["tokenizer"]
+        create_vectors = resolved_nlp["vectors"]
+        before_creation = resolved_nlp["before_creation"]
+        after_creation = resolved_nlp["after_creation"]
+        after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
+        lang_cls = cls
+        if before_creation is not None:
+            lang_cls = before_creation(cls)
+            if (
+                not isinstance(lang_cls, type)
+                or not issubclass(lang_cls, cls)
+                or lang_cls is not cls
+            ):
+                raise ValueError(Errors.E943.format(value=type(lang_cls)))
+
+        # Warn about require_gpu usage in jupyter notebook
+        warn_if_jupyter_cupy()
+
+        # Note that we don't load vectors here, instead they get loaded explicitly
+        # inside stuff like the spacy train function. If we loaded them here,
+        # then we would load them twice at runtime: once when we make from config,
+        # and then again when we load from disk.
+        nlp = lang_cls(
+            vocab=vocab,
+            create_tokenizer=create_tokenizer,
+            create_vectors=create_vectors,
+            meta=meta,
+        )
+        if after_creation is not None:
+            nlp = after_creation(nlp)
+            if not isinstance(nlp, cls):
+                raise ValueError(Errors.E942.format(name="creation", value=type(nlp)))
+        # To create the components we need to use the final interpolated config
+        # so all values are available (if component configs use variables).
+        # Later we replace the component config with the raw config again.
+        interpolated = filled.interpolate() if not filled.is_interpolated else filled
+        pipeline = interpolated.get("components", {})
+        # If components are loaded from a source (existing models), we cache
+        # them here so they're only loaded once
+        source_nlps = {}
+        source_nlp_vectors_hashes = {}
+        vocab_b = None
+        for pipe_name in config["nlp"]["pipeline"]:
+            if pipe_name not in pipeline:
+                opts = ", ".join(pipeline.keys())
+                raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
+            pipe_cfg = util.copy_config(pipeline[pipe_name])
+            raw_config = Config(filled["components"][pipe_name])
+            if pipe_name not in exclude:
+                if "factory" not in pipe_cfg and "source" not in pipe_cfg:
+                    err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
+                    raise ValueError(err)
+                if "factory" in pipe_cfg:
+                    factory = pipe_cfg.pop("factory")
+                    # The pipe name (key in the config) here is the unique name
+                    # of the component, not necessarily the factory
+                    nlp.add_pipe(
+                        factory,
+                        name=pipe_name,
+                        config=pipe_cfg,
+                        validate=validate,
+                        raw_config=raw_config,
+                    )
+                else:
+                    assert "source" in pipe_cfg
+                    # We need the sourced components to reference the same
+                    # vocab without modifying the current vocab state **AND**
+                    # we still want to load the source model vectors to perform
+                    # the vectors check. Since the source vectors clobber the
+                    # current ones, we save the original vocab state and
+                    # restore after this loop. Existing strings are preserved
+                    # during deserialization, so they do not need any
+                    # additional handling.
+                    if vocab_b is None:
+                        vocab_b = nlp.vocab.to_bytes(exclude=["lookups", "strings"])
+                    model = pipe_cfg["source"]
+                    if model not in source_nlps:
+                        # Load with the same vocab, adding any strings
+                        source_nlps[model] = util.load_model(
+                            model, vocab=nlp.vocab, exclude=["lookups"]
+                        )
+                    source_name = pipe_cfg.get("component", pipe_name)
+                    listeners_replaced = False
+                    if "replace_listeners" in pipe_cfg:
+                        # Make sure that the listened-to component has the
+                        # state of the source pipeline listener map so that the
+                        # replace_listeners method below works as intended.
+                        source_nlps[model]._link_components()
+                        for name, proc in source_nlps[model].pipeline:
+                            if source_name in getattr(proc, "listening_components", []):
+                                source_nlps[model].replace_listeners(
+                                    name, source_name, pipe_cfg["replace_listeners"]
+                                )
+                                listeners_replaced = True
+                    with warnings.catch_warnings():
+                        warnings.filterwarnings("ignore", message="\\[W113\\]")
+                        nlp.add_pipe(
+                            source_name, source=source_nlps[model], name=pipe_name
+                        )
+                        # At this point after nlp.add_pipe, the listener map
+                        # corresponds to the new pipeline.
+                    if model not in source_nlp_vectors_hashes:
+                        source_nlp_vectors_hashes[model] = hash(
+                            source_nlps[model].vocab.vectors.to_bytes(
+                                exclude=["strings"]
+                            )
+                        )
+                    if "_sourced_vectors_hashes" not in nlp.meta:
+                        nlp.meta["_sourced_vectors_hashes"] = {}
+                    nlp.meta["_sourced_vectors_hashes"][
+                        pipe_name
+                    ] = source_nlp_vectors_hashes[model]
+                    # Delete from cache if listeners were replaced
+                    if listeners_replaced:
+                        del source_nlps[model]
+        # Restore the original vocab after sourcing if necessary
+        if vocab_b is not None:
+            nlp.vocab.from_bytes(vocab_b)
+
+        # Resolve disabled/enabled settings.
+        if isinstance(disable, str):
+            disable = [disable]
+        if isinstance(enable, str):
+            enable = [enable]
+        if isinstance(exclude, str):
+            exclude = [exclude]
+
+        # `enable` should not be merged with `enabled` (the opposite is true for `disable`/`disabled`). If the config
+        # specifies values for `enabled` not included in `enable`, emit warning.
+        if id(enable) != id(_DEFAULT_EMPTY_PIPES):
+            enabled = config["nlp"].get("enabled", [])
+            if len(enabled) and not set(enabled).issubset(enable):
+                warnings.warn(
+                    Warnings.W123.format(
+                        enable=enable,
+                        enabled=enabled,
+                    )
+                )
+
+        # Ensure sets of disabled/enabled pipe names are not contradictory.
+        disabled_pipes = cls._resolve_component_status(
+            list({*disable, *config["nlp"].get("disabled", [])}),
+            enable,
+            config["nlp"]["pipeline"],
+        )
+        nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
+
+        nlp.batch_size = config["nlp"]["batch_size"]
+        nlp.config = filled if auto_fill else config
+        if after_pipeline_creation is not None:
+            nlp = after_pipeline_creation(nlp)
+            if not isinstance(nlp, cls):
+                raise ValueError(
+                    Errors.E942.format(name="pipeline_creation", value=type(nlp))
+                )
+        return nlp
+
+    def replace_listeners(
+        self,
+        tok2vec_name: str,
+        pipe_name: str,
+        listeners: Iterable[str],
+    ) -> None:
+        """Find listener layers (connecting to a token-to-vector embedding
+        component) of a given pipeline component model and replace
+        them with a standalone copy of the token-to-vector layer. This can be
+        useful when training a pipeline with components sourced from an existing
+        pipeline: if multiple components (e.g. tagger, parser, NER) listen to
+        the same tok2vec component, but some of them are frozen and not updated,
+        their performance may degrade significantly as the tok2vec component is
+        updated with new data. To prevent this, listeners can be replaced with
+        a standalone tok2vec layer that is owned by the component and doesn't
+        change if the component isn't updated.
+
+        tok2vec_name (str): Name of the token-to-vector component, typically
+            "tok2vec" or "transformer".
+        pipe_name (str): Name of pipeline component to replace listeners for.
+        listeners (Iterable[str]): The paths to the listeners, relative to the
+            component config, e.g. ["model.tok2vec"]. Typically, implementations
+            will only connect to one tok2vec component, [model.tok2vec], but in
+            theory, custom models can use multiple listeners. The value here can
+            either be an empty list to not replace any listeners, or a complete
+            (!) list of the paths to all listener layers used by the model.
+
+        DOCS: https://spacy.io/api/language#replace_listeners
+        """
+        if tok2vec_name not in self.pipe_names:
+            err = Errors.E889.format(
+                tok2vec=tok2vec_name,
+                name=pipe_name,
+                unknown=tok2vec_name,
+                opts=", ".join(self.pipe_names),
+            )
+            raise ValueError(err)
+        if pipe_name not in self.pipe_names:
+            err = Errors.E889.format(
+                tok2vec=tok2vec_name,
+                name=pipe_name,
+                unknown=pipe_name,
+                opts=", ".join(self.pipe_names),
+            )
+            raise ValueError(err)
+        tok2vec = self.get_pipe(tok2vec_name)
+        tok2vec_cfg = self.get_pipe_config(tok2vec_name)
+        if not isinstance(tok2vec, ty.ListenedToComponent):
+            raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
+        tok2vec_model = tok2vec.model
+        pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
+        pipe = self.get_pipe(pipe_name)
+        pipe_cfg = self._pipe_configs[pipe_name]
+        if listeners:
+            util.logger.debug("Replacing listeners of component '%s'", pipe_name)
+            if len(list(listeners)) != len(pipe_listeners):
+                # The number of listeners defined in the component model doesn't
+                # match the listeners to replace, so we won't be able to update
+                # the nodes and generate a matching config
+                err = Errors.E887.format(
+                    name=pipe_name,
+                    tok2vec=tok2vec_name,
+                    paths=listeners,
+                    n_listeners=len(pipe_listeners),
+                )
+                raise ValueError(err)
+            # Update the config accordingly by copying the tok2vec model to all
+            # sections defined in the listener paths
+            for listener_path in listeners:
+                # Check if the path actually exists in the config
+                try:
+                    util.dot_to_object(pipe_cfg, listener_path)
+                except KeyError:
+                    err = Errors.E886.format(
+                        name=pipe_name, tok2vec=tok2vec_name, path=listener_path
+                    )
+                    raise ValueError(err)
+                new_config = tok2vec_cfg["model"]
+                if "replace_listener_cfg" in tok2vec_model.attrs:
+                    replace_func = tok2vec_model.attrs["replace_listener_cfg"]
+                    new_config = replace_func(
+                        tok2vec_cfg["model"], pipe_cfg["model"]["tok2vec"]
+                    )
+                util.set_dot_to_object(pipe_cfg, listener_path, new_config)
+            # Go over the listener layers and replace them
+            for listener in pipe_listeners:
+                new_model = tok2vec_model.copy()
+                replace_listener_func = tok2vec_model.attrs.get("replace_listener")
+                if replace_listener_func is not None:
+                    # Pass the extra args to the callback without breaking compatibility with
+                    # old library versions that only expect a single parameter.
+                    num_params = len(
+                        inspect.signature(replace_listener_func).parameters
+                    )
+                    if num_params == 1:
+                        new_model = replace_listener_func(new_model)
+                    elif num_params == 3:
+                        new_model = replace_listener_func(new_model, listener, tok2vec)
+                    else:
+                        raise ValueError(Errors.E1055.format(num_params=num_params))
+
+                util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
+                tok2vec.remove_listener(listener, pipe_name)
+
+    @contextmanager
+    def memory_zone(self, mem: Optional[Pool] = None) -> Iterator[Pool]:
+        """Begin a block where all resources allocated during the block will
+        be freed at the end of it. If a resource was created within the
+        memory zone block, accessing it outside the block is invalid.
+        Behaviour of this invalid access is undefined. Memory zones should
+        not be nested.
+
+        The memory zone is helpful for services that need to process large
+        volumes of text with a defined memory budget.
+
+        Example
+        -------
+        >>> with nlp.memory_zone():
+        ...     for doc in nlp.pipe(texts):
+        ...        process_my_doc(doc)
+        >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
+        """
+        if mem is None:
+            mem = Pool()
+        # The ExitStack allows programmatic nested context managers.
+        # We don't know how many we need, so it would be awkward to have
+        # them as nested blocks.
+        with ExitStack() as stack:
+            contexts = [stack.enter_context(self.vocab.memory_zone(mem))]
+            if hasattr(self.tokenizer, "memory_zone"):
+                contexts.append(stack.enter_context(self.tokenizer.memory_zone(mem)))
+            for _, pipe in self.pipeline:
+                if hasattr(pipe, "memory_zone"):
+                    contexts.append(stack.enter_context(pipe.memory_zone(mem)))
+            yield mem
+
+    def to_disk(
+        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
+    ) -> None:
+        """Save the current state to a directory.  If a model is loaded, this
+        will include the model.
+
+        path (str / Path): Path to a directory, which will be created if
+            it doesn't exist.
+        exclude (Iterable[str]): Names of components or serialization fields to exclude.
+
+        DOCS: https://spacy.io/api/language#to_disk
+        """
+        path = util.ensure_path(path)
+        serializers = {}
+        serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(  # type: ignore[union-attr]
+            p, exclude=["vocab"]
+        )
+        serializers["meta.json"] = lambda p: srsly.write_json(
+            p, _replace_numpy_floats(self.meta)
+        )
+        serializers["config.cfg"] = lambda p: self.config.to_disk(p)
+        for name, proc in self._components:
+            if name in exclude:
+                continue
+            if not hasattr(proc, "to_disk"):
+                continue
+            serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"])  # type: ignore[misc]
+        serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
+        util.to_disk(path, serializers, exclude)
+
+    @staticmethod
+    def _resolve_component_status(
+        disable: Union[str, Iterable[str]],
+        enable: Union[str, Iterable[str]],
+        pipe_names: Iterable[str],
+    ) -> Tuple[str, ...]:
+        """Derives whether (1) `disable` and `enable` values are consistent and (2)
+        resolves those to a single set of disabled components. Raises an error in
+        case of inconsistency.
+
+        disable (Union[str, Iterable[str]]): Name(s) of component(s) or serialization fields to disable.
+        enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable.
+        pipe_names (Iterable[str]): Names of all pipeline components.
+
+        RETURNS (Tuple[str, ...]): Names of components to exclude from pipeline w.r.t.
+                                   specified includes and excludes.
+        """
+
+        if isinstance(disable, str):
+            disable = [disable]
+        to_disable = disable
+
+        if enable:
+            if isinstance(enable, str):
+                enable = [enable]
+            to_disable = {
+                *[pipe_name for pipe_name in pipe_names if pipe_name not in enable],
+                *disable,
+            }
+            # If any pipe to be enabled is in to_disable, the specification is inconsistent.
+            if len(set(enable) & to_disable):
+                raise ValueError(Errors.E1042.format(enable=enable, disable=disable))
+
+        return tuple(to_disable)
+
+    def from_disk(
+        self,
+        path: Union[str, Path],
+        *,
+        exclude: Iterable[str] = SimpleFrozenList(),
+        overrides: Dict[str, Any] = SimpleFrozenDict(),
+    ) -> "Language":
+        """Loads state from a directory. Modifies the object in place and
+        returns it. If the saved `Language` object contains a model, the
+        model will be loaded.
+
+        path (str / Path): A path to a directory.
+        exclude (Iterable[str]): Names of components or serialization fields to exclude.
+        RETURNS (Language): The modified `Language` object.
+
+        DOCS: https://spacy.io/api/language#from_disk
+        """
+
+        def deserialize_meta(path: Path) -> None:
+            if path.exists():
+                data = srsly.read_json(path)
+                self.meta.update(data)
+                # self.meta always overrides meta["vectors"] with the metadata
+                # from self.vocab.vectors, so set the name directly
+                self.vocab.vectors.name = data.get("vectors", {}).get("name")
+
+        def deserialize_vocab(path: Path) -> None:
+            if path.exists():
+                self.vocab.from_disk(path, exclude=exclude)
+
+        path = util.ensure_path(path)
+        deserializers = {}
+        if Path(path / "config.cfg").exists():  # type: ignore[operator]
+            deserializers["config.cfg"] = lambda p: self.config.from_disk(
+                p, interpolate=False, overrides=overrides
+            )
+        deserializers["meta.json"] = deserialize_meta  # type: ignore[assignment]
+        deserializers["vocab"] = deserialize_vocab  # type: ignore[assignment]
+        deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(  # type: ignore[union-attr]
+            p, exclude=["vocab"]
+        )
+        for name, proc in self._components:
+            if name in exclude:
+                continue
+            if not hasattr(proc, "from_disk"):
+                continue
+            deserializers[name] = lambda p, proc=proc: proc.from_disk(  # type: ignore[misc]
+                p, exclude=["vocab"]
+            )
+        if not (path / "vocab").exists() and "vocab" not in exclude:  # type: ignore[operator]
+            # Convert to list here in case exclude is (default) tuple
+            exclude = list(exclude) + ["vocab"]
+        util.from_disk(path, deserializers, exclude)  # type: ignore[arg-type]
+        self._path = path  # type: ignore[assignment]
+        self._link_components()
+        return self
+
+    def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
+        """Serialize the current state to a binary string.
+
+        exclude (Iterable[str]): Names of components or serialization fields to exclude.
+        RETURNS (bytes): The serialized form of the `Language` object.
+
+        DOCS: https://spacy.io/api/language#to_bytes
+        """
+        serializers: Dict[str, Callable[[], bytes]] = {}
+        serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
+        serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])  # type: ignore[union-attr]
+        serializers["meta.json"] = lambda: srsly.json_dumps(
+            _replace_numpy_floats(self.meta)
+        )
+        serializers["config.cfg"] = lambda: self.config.to_bytes()
+        for name, proc in self._components:
+            if name in exclude:
+                continue
+            if not hasattr(proc, "to_bytes"):
+                continue
+            serializers[name] = lambda proc=proc: proc.to_bytes(exclude=["vocab"])  # type: ignore[misc]
+        return util.to_bytes(serializers, exclude)
+
+    def from_bytes(
+        self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
+    ) -> "Language":
+        """Load state from a binary string.
+
+        bytes_data (bytes): The data to load from.
+        exclude (Iterable[str]): Names of components or serialization fields to exclude.
+        RETURNS (Language): The `Language` object.
+
+        DOCS: https://spacy.io/api/language#from_bytes
+        """
+
+        def deserialize_meta(b):
+            data = srsly.json_loads(b)
+            self.meta.update(data)
+            # self.meta always overrides meta["vectors"] with the metadata
+            # from self.vocab.vectors, so set the name directly
+            self.vocab.vectors.name = data.get("vectors", {}).get("name")
+
+        deserializers: Dict[str, Callable[[bytes], Any]] = {}
+        deserializers["config.cfg"] = lambda b: self.config.from_bytes(
+            b, interpolate=False
+        )
+        deserializers["meta.json"] = deserialize_meta
+        deserializers["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
+        deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(  # type: ignore[union-attr]
+            b, exclude=["vocab"]
+        )
+        for name, proc in self._components:
+            if name in exclude:
+                continue
+            if not hasattr(proc, "from_bytes"):
+                continue
+            deserializers[name] = lambda b, proc=proc: proc.from_bytes(  # type: ignore[misc]
+                b, exclude=["vocab"]
+            )
+        util.from_bytes(bytes_data, deserializers, exclude)
+        self._link_components()
+        return self
+
+

A text-processing pipeline. Usually you'll load this once per process, and pass the instance around your application.

+

Defaults (class): Settings, data and factory methods for creating the nlp object and processing pipeline.
lang (str): IETF language code, such as 'en'.

+

DOCS: https://spacy.io/api/language

+

Initialise a Language object.

+

vocab (Vocab): A Vocab object. If True, a vocab is created.
meta (dict): Custom meta data for the Language class. Is written to by models to add model meta data.
max_length (int): Maximum number of characters in a single text. The current models may run out of memory on extremely long texts, due to large internal allocations. You should segment these texts into meaningful units, e.g. paragraphs, subsections etc., before passing them to spaCy. The default maximum length is 1,000,000 characters (about 1 MB). As a rule of thumb, if all pipeline components are enabled, spaCy's default models currently require roughly 1 GB of temporary memory per 100,000 characters in one text.
create_tokenizer (Callable): Function that takes the nlp object and returns a tokenizer.
batch_size (int): Default batch size for pipe and evaluate.

+

DOCS: https://spacy.io/api/language#init

+
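A minimal usage sketch (assuming spaCy is installed; the blank German pipeline is used purely for illustration and is not part of this project's API):

>>> import spacy
>>> nlp = spacy.blank("de")  # German() instance with an empty pipeline
>>> doc = nlp("Das ist ein Satz.")
>>> [t.text for t in doc]
['Das', 'ist', 'ein', 'Satz', '.']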

Subclasses

+
  • spacy.lang.de.German
  • spacy.lang.ja.Japanese
  • spacy.lang.xx.MultiLanguage

Class variables

+
+
var Defaults
+
+

Language data defaults, available via Language.Defaults. Can be overwritten by language subclasses by defining their own subclasses of Language.Defaults.

+
+
var default_config
+
+
+
+
var factories
+
+
+
+
var lang : str | None
+
+
+
+
+

Static methods

+
+
+def component(name: str,
*,
assigns: Iterable[str] = [],
requires: Iterable[str] = [],
retokenizes: bool = False,
func: Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc] | None = None) ‑> Callable[..., Any]
+
+
+

Register a new pipeline component. Can be used for stateless function components that don't require a separate factory. Can be used as a decorator on a function or classmethod, or called as a function with the factory provided as the func keyword argument. To create a component and add it to the pipeline, you can use nlp.add_pipe(name).

+

name (str): The name of the component factory.
assigns (Iterable[str]): Doc/Token attributes assigned by this component, e.g. "token.ent_id". Used for pipeline analysis.
requires (Iterable[str]): Doc/Token attributes required by this component, e.g. "token.ent_id". Used for pipeline analysis.
retokenizes (bool): Whether the component changes the tokenization. Used for pipeline analysis.
func (Optional[Callable[[Doc], Doc]]): Factory function if not used as a decorator.

+

DOCS: https://spacy.io/api/language#component

+
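As a short sketch of the decorator form described above (the component name "clean_doc" and the existing nlp object are assumptions made for illustration):

>>> from spacy.language import Language
>>> @Language.component("clean_doc")
... def clean_doc(doc):
...     # stateless component: receives a Doc, returns the (possibly modified) Doc
...     return doc
>>> nlp.add_pipe("clean_doc")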
+
+def factory(name: str,
*,
default_config: Dict[str, Any] = {},
assigns: Iterable[str] = [],
requires: Iterable[str] = [],
retokenizes: bool = False,
default_score_weights: Dict[str, float | None] = {},
func: Callable | None = None) ‑> Callable
+
+
+

Register a new pipeline component factory. Can be used as a decorator on a function or classmethod, or called as a function with the factory provided as the func keyword argument. To create a component and add it to the pipeline, you can use nlp.add_pipe(name).

+

name (str): The name of the component factory.
default_config (Dict[str, Any]): Default configuration, describing the default values of the factory arguments.
assigns (Iterable[str]): Doc/Token attributes assigned by this component, e.g. "token.ent_id". Used for pipeline analysis.
requires (Iterable[str]): Doc/Token attributes required by this component, e.g. "token.ent_id". Used for pipeline analysis.
retokenizes (bool): Whether the component changes the tokenization. Used for pipeline analysis.
default_score_weights (Dict[str, Optional[float]]): The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to 1.0 per component and will be combined and normalized for the whole pipeline. If None, the score won't be shown in the logs or be weighted.
func (Optional[Callable]): Factory function if not used as a decorator.

+

DOCS: https://spacy.io/api/language#factory

+
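A hedged sketch of the factory form; the name "length_logger", the config key "log_every" and the LengthLogger class are hypothetical. The factory callable receives nlp and name followed by the arguments declared in default_config:

>>> from spacy.language import Language
>>> @Language.factory("length_logger", default_config={"log_every": 100})
... def create_length_logger(nlp, name, log_every):
...     return LengthLogger(log_every=log_every)  # hypothetical stateful component class
>>> nlp.add_pipe("length_logger", config={"log_every": 10})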
+
+def from_config(config: Dict[str, Any] | confection.Config = {},
*,
vocab: spacy.vocab.Vocab | bool = True,
disable: str | Iterable[str] = [],
enable: str | Iterable[str] = [],
exclude: str | Iterable[str] = [],
meta: Dict[str, Any] = {},
auto_fill: bool = True,
validate: bool = True) ‑> spacy.language.Language
+
+
+

Create the nlp object from a loaded config. Will set up the tokenizer and language data, add pipeline components etc. If no config is provided, the default config of the given language is used.

+

config (Dict[str, Any] / Config): The loaded config.
vocab (Vocab): A Vocab object. If True, a vocab is created.
disable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to disable. Disabled pipes will be loaded but they won't be run unless you explicitly enable them by calling nlp.enable_pipe.
enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable. All other pipes will be disabled (and can be enabled using nlp.enable_pipe).
exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude. Excluded components won't be loaded.
meta (Dict[str, Any]): Meta overrides for nlp.meta.
auto_fill (bool): Automatically fill in missing values in config based on defaults and function argument annotations.
validate (bool): Validate the component config and arguments against the types expected by the factory.
RETURNS (Language): The initialized Language class.

+

DOCS: https://spacy.io/api/language#from_config

+
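A minimal sketch (the "./config.cfg" path is illustrative; Config is the confection/thinc config class already referenced in these docs):

>>> from spacy.lang.de import German
>>> from thinc.api import Config
>>> config = Config().from_disk("./config.cfg")
>>> nlp = German.from_config(config)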
+
+def get_factory_meta(name: str) ‑> spacy.language.FactoryMeta +
+
+

Get the meta information for a given factory name.

+

name (str): The component factory name. +RETURNS (FactoryMeta): The meta for the given factory name.

+
+
+def get_factory_name(name: str) ‑> str +
+
+

Get the internal factory name based on the language subclass.

+

name (str): The factory name. +RETURNS (str): The internal factory name.

+
+
+def has_factory(name: str) ‑> bool +
+
+

RETURNS (bool): Whether a factory of that name is registered.

+
+
+def set_factory_meta(name: str, value: FactoryMeta) ‑> None +
+
+

Set the meta information for a given factory name.

+

name (str): The component factory name. +value (FactoryMeta): The meta to set.

+
+
+

Instance variables

+
+
prop component_names : List[str]
+
+
+ +Expand source code + +
@property
+def component_names(self) -> List[str]:
+    """Get the names of the available pipeline components. Includes all
+    active and inactive pipeline components.
+
+    RETURNS (List[str]): List of component name strings, in order.
+    """
+    names = [pipe_name for pipe_name, _ in self._components]
+    return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
+
+

Get the names of the available pipeline components. Includes all active and inactive pipeline components.

+

RETURNS (List[str]): List of component name strings, in order.

+
+
prop components : List[Tuple[str, Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc]]]
+
+
+ +Expand source code + +
@property
+def components(self) -> List[Tuple[str, PipeCallable]]:
+    """Get all (name, component) tuples in the pipeline, including the
+    currently disabled components.
+    """
+    return SimpleFrozenList(
+        self._components, error=Errors.E926.format(attr="components")
+    )
+
+

Get all (name, component) tuples in the pipeline, including the +currently disabled components.

+
+
prop config : confection.Config
+
+
+ +Expand source code + +
@property
+def config(self) -> Config:
+    """Trainable config for the current language instance. Includes the
+    current pipeline components, as well as default training config.
+
+    RETURNS (thinc.api.Config): The config.
+
+    DOCS: https://spacy.io/api/language#config
+    """
+    self._config.setdefault("nlp", {})
+    self._config.setdefault("training", {})
+    self._config["nlp"]["lang"] = self.lang
+    # We're storing the filled config for each pipeline component and so
+    # we can populate the config again later
+    pipeline = {}
+    score_weights = []
+    for pipe_name in self.component_names:
+        pipe_meta = self.get_pipe_meta(pipe_name)
+        pipe_config = self.get_pipe_config(pipe_name)
+        pipeline[pipe_name] = {"factory": pipe_meta.factory, **pipe_config}
+        if pipe_meta.default_score_weights:
+            score_weights.append(pipe_meta.default_score_weights)
+    self._config["nlp"]["pipeline"] = list(self.component_names)
+    self._config["nlp"]["disabled"] = list(self.disabled)
+    self._config["components"] = pipeline
+    # We're merging the existing score weights back into the combined
+    # weights to make sure we're preserving custom settings in the config
+    # but also reflect updates (e.g. new components added)
+    prev_weights = self._config["training"].get("score_weights", {})
+    combined_score_weights = combine_score_weights(score_weights, prev_weights)
+    self._config["training"]["score_weights"] = combined_score_weights
+    if not srsly.is_json_serializable(self._config):
+        raise ValueError(Errors.E961.format(config=self._config))
+    return self._config
+
+

Trainable config for the current language instance. Includes the current pipeline components, as well as default training config.

+

RETURNS (thinc.api.Config): The config.

+

DOCS: https://spacy.io/api/language#config

+
+
prop disabled : List[str]
+
+
+ +Expand source code + +
@property
+def disabled(self) -> List[str]:
+    """Get the names of all disabled components.
+
+    RETURNS (List[str]): The disabled components.
+    """
+    # Make sure the disabled components are returned in the order they
+    # appear in the pipeline (which isn't guaranteed by the set)
+    names = [name for name, _ in self._components if name in self._disabled]
+    return SimpleFrozenList(names, error=Errors.E926.format(attr="disabled"))
+
+

Get the names of all disabled components.

+

RETURNS (List[str]): The disabled components.

+
+
prop factory_names : List[str]
+
+
+ +Expand source code + +
@property
+def factory_names(self) -> List[str]:
+    """Get names of all available factories.
+
+    RETURNS (List[str]): The factory names.
+    """
+    names = list(self.factories.keys())
+    return SimpleFrozenList(names)
+
+

Get names of all available factories.

+

RETURNS (List[str]): The factory names.

+
+
prop meta : Dict[str, Any]
+
+
+ +Expand source code + +
@property
+def meta(self) -> Dict[str, Any]:
+    """Custom meta data of the language class. If a model is loaded, this
+    includes details from the model's meta.json.
+
+    RETURNS (Dict[str, Any]): The meta.
+
+    DOCS: https://spacy.io/api/language#meta
+    """
+    spacy_version = util.get_minor_version_range(about.__version__)
+    if self.vocab.lang:
+        self._meta.setdefault("lang", self.vocab.lang)
+    else:
+        self._meta.setdefault("lang", self.lang)
+    self._meta.setdefault("name", "pipeline")
+    self._meta.setdefault("version", "0.0.0")
+    self._meta.setdefault("spacy_version", spacy_version)
+    self._meta.setdefault("description", "")
+    self._meta.setdefault("author", "")
+    self._meta.setdefault("email", "")
+    self._meta.setdefault("url", "")
+    self._meta.setdefault("license", "")
+    self._meta.setdefault("spacy_git_version", GIT_VERSION)
+    self._meta["vectors"] = {
+        "width": self.vocab.vectors_length,
+        "vectors": len(self.vocab.vectors),
+        "keys": self.vocab.vectors.n_keys,
+        "name": self.vocab.vectors.name,
+        "mode": self.vocab.vectors.mode,
+    }
+    self._meta["labels"] = dict(self.pipe_labels)
+    # TODO: Adding this back to prevent breaking people's code etc., but
+    # we should consider removing it
+    self._meta["pipeline"] = list(self.pipe_names)
+    self._meta["components"] = list(self.component_names)
+    self._meta["disabled"] = list(self.disabled)
+    return self._meta
+
+

Custom meta data of the language class. If a model is loaded, this includes details from the model's meta.json.

+

RETURNS (Dict[str, Any]): The meta.

+

DOCS: https://spacy.io/api/language#meta

+
+
prop path
+
+
+ +Expand source code + +
@property
+def path(self):
+    return self._path
+
+
+
+
prop pipe_factories : Dict[str, str]
+
+
+ +Expand source code + +
@property
+def pipe_factories(self) -> Dict[str, str]:
+    """Get the component factories for the available pipeline components.
+
+    RETURNS (Dict[str, str]): Factory names, keyed by component names.
+    """
+    factories = {}
+    for pipe_name, pipe in self._components:
+        factories[pipe_name] = self.get_pipe_meta(pipe_name).factory
+    return SimpleFrozenDict(factories)
+
+

Get the component factories for the available pipeline components.

+

RETURNS (Dict[str, str]): Factory names, keyed by component names.

+
+
prop pipe_labels : Dict[str, List[str]]
+
+
+ +Expand source code + +
@property
+def pipe_labels(self) -> Dict[str, List[str]]:
+    """Get the labels set by the pipeline components, if available (if
+    the component exposes a labels property and the labels are not
+    hidden).
+
+    RETURNS (Dict[str, List[str]]): Labels keyed by component name.
+    """
+    labels = {}
+    for name, pipe in self._components:
+        if hasattr(pipe, "hide_labels") and pipe.hide_labels is True:
+            continue
+        if hasattr(pipe, "labels"):
+            labels[name] = list(pipe.labels)
+    return SimpleFrozenDict(labels)
+
+

Get the labels set by the pipeline components, if available (if the component exposes a labels property and the labels are not hidden).

+

RETURNS (Dict[str, List[str]]): Labels keyed by component name.

+
+
prop pipe_names : List[str]
+
+
+ +Expand source code + +
@property
+def pipe_names(self) -> List[str]:
+    """Get names of available active pipeline components.
+
+    RETURNS (List[str]): List of component name strings, in order.
+    """
+    names = [pipe_name for pipe_name, _ in self.pipeline]
+    return SimpleFrozenList(names, error=Errors.E926.format(attr="pipe_names"))
+
+

Get names of available active pipeline components.

+

RETURNS (List[str]): List of component name strings, in order.

+
+
prop pipeline : List[Tuple[str, Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc]]]
+
+
+ +Expand source code + +
@property
+def pipeline(self) -> List[Tuple[str, PipeCallable]]:
+    """The processing pipeline consisting of (name, component) tuples. The
+    components are called on the Doc in order as it passes through the
+    pipeline.
+
+    RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
+    """
+    pipes = [(n, p) for n, p in self._components if n not in self._disabled]
+    return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
+
+

The processing pipeline consisting of (name, component) tuples. The components are called on the Doc in order as it passes through the pipeline.

+

RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.

+
+
+

Methods

+
+
+def add_pipe(self,
factory_name: str,
name: str | None = None,
*,
before: str | int | None = None,
after: str | int | None = None,
first: bool | None = None,
last: bool | None = None,
source: ForwardRef('Language') | None = None,
config: Dict[str, Any] = {},
raw_config: confection.Config | None = None,
validate: bool = True) ‑> Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc]
+
+
+
+ +Expand source code + +
def add_pipe(
+    self,
+    factory_name: str,
+    name: Optional[str] = None,
+    *,
+    before: Optional[Union[str, int]] = None,
+    after: Optional[Union[str, int]] = None,
+    first: Optional[bool] = None,
+    last: Optional[bool] = None,
+    source: Optional["Language"] = None,
+    config: Dict[str, Any] = SimpleFrozenDict(),
+    raw_config: Optional[Config] = None,
+    validate: bool = True,
+) -> PipeCallable:
+    """Add a component to the processing pipeline. Valid components are
+    callables that take a `Doc` object, modify it and return it. Only one
+    of before/after/first/last can be set. Default behaviour is "last".
+
+    factory_name (str): Name of the component factory.
+    name (str): Name of pipeline component. Overwrites existing
+        component.name attribute if available. If no name is set and
+        the component exposes no name attribute, component.__name__ is
+        used. An error is raised if a name already exists in the pipeline.
+    before (Union[str, int]): Name or index of the component to insert new
+        component directly before.
+    after (Union[str, int]): Name or index of the component to insert new
+        component directly after.
+    first (bool): If True, insert component first in the pipeline.
+    last (bool): If True, insert component last in the pipeline.
+    source (Language): Optional loaded nlp object to copy the pipeline
+        component from.
+    config (Dict[str, Any]): Config parameters to use for this component.
+        Will be merged with default config, if available.
+    raw_config (Optional[Config]): Internals: the non-interpolated config.
+    validate (bool): Whether to validate the component config against the
+        arguments and types expected by the factory.
+    RETURNS (Callable[[Doc], Doc]): The pipeline component.
+
+    DOCS: https://spacy.io/api/language#add_pipe
+    """
+    if not isinstance(factory_name, str):
+        bad_val = repr(factory_name)
+        err = Errors.E966.format(component=bad_val, name=name)
+        raise ValueError(err)
+    name = name if name is not None else factory_name
+    if name in self.component_names:
+        raise ValueError(Errors.E007.format(name=name, opts=self.component_names))
+    # Overriding pipe name in the config is not supported and will be ignored.
+    if "name" in config:
+        warnings.warn(Warnings.W119.format(name_in_config=config.pop("name")))
+    if source is not None:
+        # We're loading the component from a model. After loading the
+        # component, we know its real factory name
+        pipe_component, factory_name = self.create_pipe_from_source(
+            factory_name, source, name=name
+        )
+    else:
+        pipe_component = self.create_pipe(
+            factory_name,
+            name=name,
+            config=config,
+            raw_config=raw_config,
+            validate=validate,
+        )
+    pipe_index = self._get_pipe_index(before, after, first, last)
+    self._pipe_meta[name] = self.get_factory_meta(factory_name)
+    self._components.insert(pipe_index, (name, pipe_component))
+    self._link_components()
+    return pipe_component
+
+

Add a component to the processing pipeline. Valid components are callables that take a Doc object, modify it and return it. Only one of before/after/first/last can be set. Default behaviour is "last".

+

factory_name (str): Name of the component factory.
name (str): Name of pipeline component. Overwrites existing component.name attribute if available. If no name is set and the component exposes no name attribute, component.__name__ is used. An error is raised if a name already exists in the pipeline.
before (Union[str, int]): Name or index of the component to insert new component directly before.
after (Union[str, int]): Name or index of the component to insert new component directly after.
first (bool): If True, insert component first in the pipeline.
last (bool): If True, insert component last in the pipeline.
source (Language): Optional loaded nlp object to copy the pipeline component from.
config (Dict[str, Any]): Config parameters to use for this component. Will be merged with default config, if available.
raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The pipeline component.

+

DOCS: https://spacy.io/api/language#add_pipe

+
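For illustration (assuming an existing nlp object; "sentencizer" is one of spaCy's built-in factories):

>>> nlp.add_pipe("sentencizer", first=True)
>>> "sentencizer" in nlp.pipe_names
True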
+
+def analyze_pipes(self,
*,
keys: List[str] = ['assigns', 'requires', 'scores', 'retokenizes'],
pretty: bool = False) ‑> Dict[str, Any] | None
+
+
+
+ +Expand source code + +
def analyze_pipes(
+    self,
+    *,
+    keys: List[str] = ["assigns", "requires", "scores", "retokenizes"],
+    pretty: bool = False,
+) -> Optional[Dict[str, Any]]:
+    """Analyze the current pipeline components, print a summary of what
+    they assign or require and check that all requirements are met.
+
+    keys (List[str]): The meta values to display in the table. Corresponds
+        to values in FactoryMeta, defined by @Language.factory decorator.
+    pretty (bool): Pretty-print the results.
+    RETURNS (dict): The data.
+    """
+    analysis = analyze_pipes(self, keys=keys)
+    if pretty:
+        print_pipe_analysis(analysis, keys=keys)
+    return analysis
+
+

Analyze the current pipeline components, print a summary of what they assign or require and check that all requirements are met.

+

keys (List[str]): The meta values to display in the table. Corresponds to values in FactoryMeta, defined by @Language.factory decorator.
pretty (bool): Pretty-print the results.
RETURNS (dict): The data.

+
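A quick sketch (assuming an nlp object with one or more components added):

>>> analysis = nlp.analyze_pipes(pretty=True)  # also prints a formatted summary table
>>> isinstance(analysis, dict)
True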
+
+def begin_training(self,
get_examples: Callable[[], Iterable[spacy.training.example.Example]] | None = None,
*,
sgd: thinc.optimizers.Optimizer | None = None) ‑> thinc.optimizers.Optimizer
+
+
+
+ +Expand source code + +
def begin_training(
+    self,
+    get_examples: Optional[Callable[[], Iterable[Example]]] = None,
+    *,
+    sgd: Optional[Optimizer] = None,
+) -> Optimizer:
+    warnings.warn(Warnings.W089, DeprecationWarning)
+    return self.initialize(get_examples, sgd=sgd)
+
+
+
+
+def create_optimizer(self) +
+
+
+ +Expand source code + +
def create_optimizer(self):
+    """Create an optimizer, usually using the [training.optimizer] config."""
+    subconfig = {"optimizer": self.config["training"]["optimizer"]}
+    return registry.resolve(subconfig)["optimizer"]
+
+

Create an optimizer, usually using the [training.optimizer] config.

+
+
+def create_pipe(self,
factory_name: str,
name: str | None = None,
*,
config: Dict[str, Any] = {},
raw_config: confection.Config | None = None,
validate: bool = True) ‑> Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc]
+
+
+
+ +Expand source code + +
def create_pipe(
+    self,
+    factory_name: str,
+    name: Optional[str] = None,
+    *,
+    config: Dict[str, Any] = SimpleFrozenDict(),
+    raw_config: Optional[Config] = None,
+    validate: bool = True,
+) -> PipeCallable:
+    """Create a pipeline component. Mostly used internally. To create and
+    add a component to the pipeline, you can use nlp.add_pipe.
+
+    factory_name (str): Name of component factory.
+    name (Optional[str]): Optional name to assign to component instance.
+        Defaults to factory name if not set.
+    config (Dict[str, Any]): Config parameters to use for this component.
+        Will be merged with default config, if available.
+    raw_config (Optional[Config]): Internals: the non-interpolated config.
+    validate (bool): Whether to validate the component config against the
+        arguments and types expected by the factory.
+    RETURNS (Callable[[Doc], Doc]): The pipeline component.
+
+    DOCS: https://spacy.io/api/language#create_pipe
+    """
+    name = name if name is not None else factory_name
+    if not isinstance(config, dict):
+        err = Errors.E962.format(style="config", name=name, cfg_type=type(config))
+        raise ValueError(err)
+    if not srsly.is_json_serializable(config):
+        raise ValueError(Errors.E961.format(config=config))
+    if not self.has_factory(factory_name):
+        err = Errors.E002.format(
+            name=factory_name,
+            opts=", ".join(self.factory_names),
+            method="create_pipe",
+            lang=util.get_object_name(self),
+            lang_code=self.lang,
+        )
+        raise ValueError(err)
+    pipe_meta = self.get_factory_meta(factory_name)
+    # This is unideal, but the alternative would mean you always need to
+    # specify the full config settings, which is not really viable.
+    if pipe_meta.default_config:
+        config = Config(pipe_meta.default_config).merge(config)
+    internal_name = self.get_factory_name(factory_name)
+    # If the language-specific factory doesn't exist, try again with the
+    # not-specific name
+    if internal_name not in registry.factories:
+        internal_name = factory_name
+    # The name allows components to know their pipe name and use it in the
+    # losses etc. (even if multiple instances of the same factory are used)
+    config = {"nlp": self, "name": name, **config, "@factories": internal_name}
+    # We need to create a top-level key because Thinc doesn't allow resolving
+    # top-level references to registered functions. Also gives nicer errors.
+    cfg = {factory_name: config}
+    # We're calling the internal _fill here to avoid constructing the
+    # registered functions twice
+    resolved = registry.resolve(cfg, validate=validate)
+    filled = registry.fill({"cfg": cfg[factory_name]}, validate=validate)["cfg"]
+    filled = Config(filled)
+    filled["factory"] = factory_name
+    filled.pop("@factories", None)
+    # Remove the extra values we added because we don't want to keep passing
+    # them around, copying them etc.
+    filled.pop("nlp", None)
+    filled.pop("name", None)
+    # Merge the final filled config with the raw config (including non-
+    # interpolated variables)
+    if raw_config:
+        filled = filled.merge(raw_config)
+    self._pipe_configs[name] = filled
+    return resolved[factory_name]
+
+

Create a pipeline component. Mostly used internally. To create and add a component to the pipeline, you can use nlp.add_pipe.

+

factory_name (str): Name of component factory. +name (Optional[str]): Optional name to assign to component instance. +Defaults to factory name if not set. +config (Dict[str, Any]): Config parameters to use for this component. +Will be merged with default config, if available. +raw_config (Optional[Config]): Internals: the non-interpolated config. +validate (bool): Whether to validate the component config against the +arguments and types expected by the factory. +RETURNS (Callable[[Doc], Doc]): The pipeline component.

+

DOCS: https://spacy.io/api/language#create_pipe

+
+
+def create_pipe_from_source(self,
source_name: str,
source: Language,
*,
name: str) ‑> Tuple[Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc], str]
+
+
+
+ +Expand source code + +
def create_pipe_from_source(
+    self, source_name: str, source: "Language", *, name: str
+) -> Tuple[PipeCallable, str]:
+    """Create a pipeline component by copying it from an existing model.
+
+    source_name (str): Name of the component in the source pipeline.
+    source (Language): The source nlp object to copy from.
+    name (str): Optional alternative name to use in current pipeline.
+    RETURNS (Tuple[Callable[[Doc], Doc], str]): The component and its factory name.
+    """
+    # Check source type
+    if not isinstance(source, Language):
+        raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
+    if self.vocab.vectors != source.vocab.vectors:
+        warnings.warn(Warnings.W113.format(name=source_name))
+    if source_name not in source.component_names:
+        raise KeyError(
+            Errors.E944.format(
+                name=source_name,
+                model=f"{source.meta['lang']}_{source.meta['name']}",
+                opts=", ".join(source.component_names),
+            )
+        )
+    pipe = source.get_pipe(source_name)
+    # There is no actual solution here. Either the component has the right
+    # name for the source pipeline or the component has the right name for
+    # the current pipeline. This prioritizes the current pipeline.
+    if hasattr(pipe, "name"):
+        pipe.name = name
+    # Make sure the source config is interpolated so we don't end up with
+    # orphaned variables in our final config
+    source_config = source.config.interpolate()
+    pipe_config = util.copy_config(source_config["components"][source_name])
+    self._pipe_configs[name] = pipe_config
+    if self.vocab.strings != source.vocab.strings:
+        for s in source.vocab.strings:
+            self.vocab.strings.add(s)
+    return pipe, pipe_config["factory"]
+
+

Create a pipeline component by copying it from an existing model.

+

source_name (str): Name of the component in the source pipeline. +source (Language): The source nlp object to copy from. +name (str): Optional alternative name to use in current pipeline. +RETURNS (Tuple[Callable[[Doc], Doc], str]): The component and its factory name.

+
+
+def disable_pipe(self, name: str) ‑> None +
+
+
+ +Expand source code + +
def disable_pipe(self, name: str) -> None:
+    """Disable a pipeline component. The component will still exist on
+    the nlp object, but it won't be run as part of the pipeline. Does
+    nothing if the component is already disabled.
+
+    name (str): The name of the component to disable.
+    """
+    if name not in self.component_names:
+        raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
+    self._disabled.add(name)
+
+

Disable a pipeline component. The component will still exist on the nlp object, but it won't be run as part of the pipeline. Does nothing if the component is already disabled.

+

name (str): The name of the component to disable.

+
+
+def disable_pipes(self, *names) ‑> spacy.language.DisabledPipes +
+
+
+ +Expand source code + +
def disable_pipes(self, *names) -> "DisabledPipes":
+    """Disable one or more pipeline components. If used as a context
+    manager, the pipeline will be restored to the initial state at the end
+    of the block. Otherwise, a DisabledPipes object is returned, that has
+    a `.restore()` method you can use to undo your changes.
+
+    This method has been deprecated since 3.0
+    """
+    warnings.warn(Warnings.W096, DeprecationWarning)
+    if len(names) == 1 and isinstance(names[0], (list, tuple)):
+        names = names[0]  # type: ignore[assignment]    # support list of names instead of spread
+    return self.select_pipes(disable=names)
+
+

Disable one or more pipeline components. If used as a context manager, the pipeline will be restored to the initial state at the end of the block. Otherwise, a DisabledPipes object is returned, that has a .restore() method you can use to undo your changes.

+

This method has been deprecated since 3.0

+
+
+def enable_pipe(self, name: str) ‑> None +
+
+
+ +Expand source code + +
def enable_pipe(self, name: str) -> None:
+    """Enable a previously disabled pipeline component so it's run as part
+    of the pipeline. Does nothing if the component is already enabled.
+
+    name (str): The name of the component to enable.
+    """
+    if name not in self.component_names:
+        raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
+    if name in self.disabled:
+        self._disabled.remove(name)
+
+

Enable a previously disabled pipeline component so it's run as part of the pipeline. Does nothing if the component is already enabled.

+

name (str): The name of the component to enable.

+
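For illustration, disable_pipe and enable_pipe together (assuming the pipeline contains a component named "ner"):

>>> nlp.disable_pipe("ner")   # kept on the nlp object, but skipped when the pipeline runs
>>> "ner" in nlp.disabled
True
>>> nlp.enable_pipe("ner")    # runs again as part of the pipeline
>>> "ner" in nlp.pipe_names
True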
+
+def evaluate(self,
examples: Iterable[spacy.training.example.Example],
*,
batch_size: int | None = None,
scorer: spacy.scorer.Scorer | None = None,
component_cfg: Dict[str, Dict[str, Any]] | None = None,
scorer_cfg: Dict[str, Any] | None = None,
per_component: bool = False) ‑> Dict[str, Any]
+
+
+
+ +Expand source code + +
def evaluate(
+    self,
+    examples: Iterable[Example],
+    *,
+    batch_size: Optional[int] = None,
+    scorer: Optional[Scorer] = None,
+    component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+    scorer_cfg: Optional[Dict[str, Any]] = None,
+    per_component: bool = False,
+) -> Dict[str, Any]:
+    """Evaluate a model's pipeline components.
+
+    examples (Iterable[Example]): `Example` objects.
+    batch_size (Optional[int]): Batch size to use.
+    scorer (Optional[Scorer]): Scorer to use. If not passed in, a new one
+        will be created.
+    component_cfg (dict): An optional dictionary with extra keyword
+        arguments for specific components.
+    scorer_cfg (dict): An optional dictionary with extra keyword arguments
+        for the scorer.
+    per_component (bool): Whether to return the scores keyed by component
+        name. Defaults to False.
+
+    RETURNS (Scorer): The scorer containing the evaluation results.
+
+    DOCS: https://spacy.io/api/language#evaluate
+    """
+    examples = list(examples)
+    validate_examples(examples, "Language.evaluate")
+    examples = _copy_examples(examples)
+    if batch_size is None:
+        batch_size = self.batch_size
+    if component_cfg is None:
+        component_cfg = {}
+    if scorer_cfg is None:
+        scorer_cfg = {}
+    if scorer is None:
+        kwargs = dict(scorer_cfg)
+        kwargs.setdefault("nlp", self)
+        scorer = Scorer(**kwargs)
+    # reset annotation in predicted docs and time tokenization
+    start_time = timer()
+    # this is purely for timing
+    for eg in examples:
+        self.make_doc(eg.reference.text)
+    # apply all pipeline components
+    docs = self.pipe(
+        (eg.predicted for eg in examples),
+        batch_size=batch_size,
+        component_cfg=component_cfg,
+    )
+    for eg, doc in zip(examples, docs):
+        eg.predicted = doc
+    end_time = timer()
+    results = scorer.score(examples, per_component=per_component)
+    n_words = sum(len(eg.predicted) for eg in examples)
+    results["speed"] = n_words / (end_time - start_time)
+    return _replace_numpy_floats(results)
+
+

Evaluate a model's pipeline components.

+

examples (Iterable[Example]): Example objects.
batch_size (Optional[int]): Batch size to use.
scorer (Optional[Scorer]): Scorer to use. If not passed in, a new one will be created.
component_cfg (dict): An optional dictionary with extra keyword arguments for specific components.
scorer_cfg (dict): An optional dictionary with extra keyword arguments for the scorer.
per_component (bool): Whether to return the scores keyed by component name. Defaults to False.

+

RETURNS (Scorer): The scorer containing the evaluation results.

+

DOCS: https://spacy.io/api/language#evaluate

+
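A hedged sketch (my_dev_data is a hypothetical list of (text, annotations) pairs used to build gold-standard Example objects):

>>> from spacy.training import Example
>>> examples = [Example.from_dict(nlp.make_doc(text), annots) for text, annots in my_dev_data]
>>> scores = nlp.evaluate(examples)
>>> "speed" in scores  # words per second is always reported
True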
+
+def from_bytes(self, bytes_data: bytes, *, exclude: Iterable[str] = []) ‑> spacy.language.Language +
+
+
+ +Expand source code + +
def from_bytes(
+    self, bytes_data: bytes, *, exclude: Iterable[str] = SimpleFrozenList()
+) -> "Language":
+    """Load state from a binary string.
+
+    bytes_data (bytes): The data to load from.
+    exclude (Iterable[str]): Names of components or serialization fields to exclude.
+    RETURNS (Language): The `Language` object.
+
+    DOCS: https://spacy.io/api/language#from_bytes
+    """
+
+    def deserialize_meta(b):
+        data = srsly.json_loads(b)
+        self.meta.update(data)
+        # self.meta always overrides meta["vectors"] with the metadata
+        # from self.vocab.vectors, so set the name directly
+        self.vocab.vectors.name = data.get("vectors", {}).get("name")
+
+    deserializers: Dict[str, Callable[[bytes], Any]] = {}
+    deserializers["config.cfg"] = lambda b: self.config.from_bytes(
+        b, interpolate=False
+    )
+    deserializers["meta.json"] = deserialize_meta
+    deserializers["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
+    deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(  # type: ignore[union-attr]
+        b, exclude=["vocab"]
+    )
+    for name, proc in self._components:
+        if name in exclude:
+            continue
+        if not hasattr(proc, "from_bytes"):
+            continue
+        deserializers[name] = lambda b, proc=proc: proc.from_bytes(  # type: ignore[misc]
+            b, exclude=["vocab"]
+        )
+    util.from_bytes(bytes_data, deserializers, exclude)
+    self._link_components()
+    return self
+
+

Load state from a binary string.

+

bytes_data (bytes): The data to load from.
exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (Language): The Language object.

+

DOCS: https://spacy.io/api/language#from_bytes

+
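A round-trip sketch combining to_bytes and from_bytes (assuming an existing nlp object; the target pipeline is rebuilt from the same config before loading the bytes, which is the pattern the spaCy docs recommend for pipelines with components):

>>> import spacy
>>> bytes_data = nlp.to_bytes()
>>> lang_cls = spacy.util.get_lang_class(nlp.config["nlp"]["lang"])
>>> nlp2 = lang_cls.from_config(nlp.config)
>>> nlp2 = nlp2.from_bytes(bytes_data)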
+
+def from_disk(self,
path: str | pathlib.Path,
*,
exclude: Iterable[str] = [],
overrides: Dict[str, Any] = {}) ‑> spacy.language.Language
+
+
+
+ +Expand source code + +
def from_disk(
+    self,
+    path: Union[str, Path],
+    *,
+    exclude: Iterable[str] = SimpleFrozenList(),
+    overrides: Dict[str, Any] = SimpleFrozenDict(),
+) -> "Language":
+    """Loads state from a directory. Modifies the object in place and
+    returns it. If the saved `Language` object contains a model, the
+    model will be loaded.
+
+    path (str / Path): A path to a directory.
+    exclude (Iterable[str]): Names of components or serialization fields to exclude.
+    RETURNS (Language): The modified `Language` object.
+
+    DOCS: https://spacy.io/api/language#from_disk
+    """
+
+    def deserialize_meta(path: Path) -> None:
+        if path.exists():
+            data = srsly.read_json(path)
+            self.meta.update(data)
+            # self.meta always overrides meta["vectors"] with the metadata
+            # from self.vocab.vectors, so set the name directly
+            self.vocab.vectors.name = data.get("vectors", {}).get("name")
+
+    def deserialize_vocab(path: Path) -> None:
+        if path.exists():
+            self.vocab.from_disk(path, exclude=exclude)
+
+    path = util.ensure_path(path)
+    deserializers = {}
+    if Path(path / "config.cfg").exists():  # type: ignore[operator]
+        deserializers["config.cfg"] = lambda p: self.config.from_disk(
+            p, interpolate=False, overrides=overrides
+        )
+    deserializers["meta.json"] = deserialize_meta  # type: ignore[assignment]
+    deserializers["vocab"] = deserialize_vocab  # type: ignore[assignment]
+    deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(  # type: ignore[union-attr]
+        p, exclude=["vocab"]
+    )
+    for name, proc in self._components:
+        if name in exclude:
+            continue
+        if not hasattr(proc, "from_disk"):
+            continue
+        deserializers[name] = lambda p, proc=proc: proc.from_disk(  # type: ignore[misc]
+            p, exclude=["vocab"]
+        )
+    if not (path / "vocab").exists() and "vocab" not in exclude:  # type: ignore[operator]
+        # Convert to list here in case exclude is (default) tuple
+        exclude = list(exclude) + ["vocab"]
+    util.from_disk(path, deserializers, exclude)  # type: ignore[arg-type]
+    self._path = path  # type: ignore[assignment]
+    self._link_components()
+    return self
+
+

Loads state from a directory. Modifies the object in place and returns it. If the saved Language object contains a model, the model will be loaded.

+

path (str / Path): A path to a directory.
exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (Language): The modified Language object.

+

DOCS: https://spacy.io/api/language#from_disk

+
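A sketch of the disk round trip (the directory path is illustrative):

>>> nlp.to_disk("/tmp/my_pipeline")  # companion method: writes config, vocab, tokenizer and components
>>> import spacy
>>> nlp2 = spacy.load("/tmp/my_pipeline")  # preferred for full pipelines: rebuilds components from the saved config, then loads from disk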
+
+def get_pipe(self, name: str) ‑> Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc] +
+
+
+ +Expand source code + +
def get_pipe(self, name: str) -> PipeCallable:
+    """Get a pipeline component for a given component name.
+
+    name (str): Name of pipeline component to get.
+    RETURNS (callable): The pipeline component.
+
+    DOCS: https://spacy.io/api/language#get_pipe
+    """
+    for pipe_name, component in self._components:
+        if pipe_name == name:
+            return component
+    raise KeyError(Errors.E001.format(name=name, opts=self.component_names))
+
+

Get a pipeline component for a given component name.

+

name (str): Name of pipeline component to get. +RETURNS (callable): The pipeline component.

+

DOCS: https://spacy.io/api/language#get_pipe

+
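For illustration (assuming a component registered under the name "sentencizer"; pipeline components are Doc-to-Doc callables, so the returned component can be applied directly):

>>> sentencizer = nlp.get_pipe("sentencizer")
>>> doc = sentencizer(nlp.make_doc("Erster Satz. Zweiter Satz."))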
+
+def get_pipe_config(self, name: str) ‑> confection.Config +
+
+
+ +Expand source code + +
def get_pipe_config(self, name: str) -> Config:
+    """Get the config used to create a pipeline component.
+
+    name (str): The component name.
+    RETURNS (Config): The config used to create the pipeline component.
+    """
+    if name not in self._pipe_configs:
+        raise ValueError(Errors.E960.format(name=name))
+    pipe_config = self._pipe_configs[name]
+    return pipe_config
+
+

Get the config used to create a pipeline component.

+

name (str): The component name. +RETURNS (Config): The config used to create the pipeline component.

+
+
+def get_pipe_meta(self, name: str) ‑> spacy.language.FactoryMeta +
+
+
+ +Expand source code + +
def get_pipe_meta(self, name: str) -> "FactoryMeta":
+    """Get the meta information for a given component name.
+
+    name (str): The component name.
+    RETURNS (FactoryMeta): The meta for the given component name.
+    """
+    if name not in self._pipe_meta:
+        raise ValueError(Errors.E967.format(meta="component", name=name))
+    return self._pipe_meta[name]
+
+

Get the meta information for a given component name.

+

name (str): The component name. +RETURNS (FactoryMeta): The meta for the given component name.

+
+
+def has_pipe(self, name: str) ‑> bool +
+
+
+ +Expand source code + +
def has_pipe(self, name: str) -> bool:
+    """Check if a component name is present in the pipeline. Equivalent to
+    `name in nlp.pipe_names`.
+
+    name (str): Name of the component.
+    RETURNS (bool): Whether a component of the name exists in the pipeline.
+
+    DOCS: https://spacy.io/api/language#has_pipe
+    """
+    return name in self.pipe_names
+
+

Check if a component name is present in the pipeline. Equivalent to name in nlp.pipe_names.

+

name (str): Name of the component. +RETURNS (bool): Whether a component of the name exists in the pipeline.

+

DOCS: https://spacy.io/api/language#has_pipe

+
+
+def initialize(self,
get_examples: Callable[[], Iterable[spacy.training.example.Example]] | None = None,
*,
sgd: thinc.optimizers.Optimizer | None = None) ‑> thinc.optimizers.Optimizer
+
+
+
+ +Expand source code + +
def initialize(
+    self,
+    get_examples: Optional[Callable[[], Iterable[Example]]] = None,
+    *,
+    sgd: Optional[Optimizer] = None,
+) -> Optimizer:
+    """Initialize the pipe for training, using data examples if available.
+
+    get_examples (Callable[[], Iterable[Example]]): Optional function that
+        returns gold-standard Example objects.
+    sgd (Optional[Optimizer]): An optimizer to use for updates. If not
+        provided, will be created using the .create_optimizer() method.
+    RETURNS (thinc.api.Optimizer): The optimizer.
+
+    DOCS: https://spacy.io/api/language#initialize
+    """
+    if get_examples is None:
+        util.logger.debug(
+            "No 'get_examples' callback provided to 'Language.initialize', creating dummy examples"
+        )
+        doc = Doc(self.vocab, words=["x", "y", "z"])
+
+        def get_examples():
+            return [Example.from_dict(doc, {})]
+
+    if not hasattr(get_examples, "__call__"):
+        err = Errors.E930.format(
+            method="Language.initialize", obj=type(get_examples)
+        )
+        raise TypeError(err)
+    # Make sure the config is interpolated so we can resolve subsections
+    config = self.config.interpolate()
+    # These are the settings provided in the [initialize] block in the config
+    I = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
+    before_init = I["before_init"]
+    if before_init is not None:
+        before_init(self)
+    try:
+        init_vocab(
+            self, data=I["vocab_data"], lookups=I["lookups"], vectors=I["vectors"]
+        )
+    except IOError:
+        raise IOError(Errors.E884.format(vectors=I["vectors"]))
+    if self.vocab.vectors.shape[1] >= 1:
+        ops = get_current_ops()
+        self.vocab.vectors.to_ops(ops)
+    if hasattr(self.tokenizer, "initialize"):
+        tok_settings = validate_init_settings(
+            self.tokenizer.initialize,  # type: ignore[union-attr]
+            I["tokenizer"],
+            section="tokenizer",
+            name="tokenizer",
+        )
+        self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)  # type: ignore[union-attr]
+    for name, proc in self.pipeline:
+        if isinstance(proc, ty.InitializableComponent):
+            p_settings = I["components"].get(name, {})
+            p_settings = validate_init_settings(
+                proc.initialize, p_settings, section="components", name=name
+            )
+            proc.initialize(get_examples, nlp=self, **p_settings)
+    pretrain_cfg = config.get("pretraining")
+    if pretrain_cfg:
+        P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
+        init_tok2vec(self, P, I)
+    self._link_components()
+    self._optimizer = sgd
+    if sgd is not None:
+        self._optimizer = sgd
+    elif self._optimizer is None:
+        self._optimizer = self.create_optimizer()
+    after_init = I["after_init"]
+    if after_init is not None:
+        after_init(self)
+    return self._optimizer
+
+

Initialize the pipe for training, using data examples if available.

+

get_examples (Callable[[], Iterable[Example]]): Optional function that returns gold-standard Example objects.
sgd (Optional[Optimizer]): An optimizer to use for updates. If not provided, will be created using the .create_optimizer() method.
RETURNS (thinc.api.Optimizer): The optimizer.

+

DOCS: https://spacy.io/api/language#initialize

+
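A minimal sketch (train_examples is a hypothetical list of Example objects; a callable returning them is passed in, as described above):

>>> optimizer = nlp.initialize(lambda: train_examples)
>>> # or, with no callback, dummy examples are created as described in the docstring:
>>> optimizer = nlp.initialize()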
+
+def make_doc(self, text: str) ‑> spacy.tokens.doc.Doc +
+
+
+ +Expand source code + +
def make_doc(self, text: str) -> Doc:
+    """Turn a text into a Doc object.
+
+    text (str): The text to process.
+    RETURNS (Doc): The processed doc.
+    """
+    if len(text) > self.max_length:
+        raise ValueError(
+            Errors.E088.format(length=len(text), max_length=self.max_length)
+        )
+    return self.tokenizer(text)
+
+

Turn a text into a Doc object.

+

text (str): The text to process. +RETURNS (Doc): The processed doc.

+
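For illustration (assuming nlp is a German pipeline as in the earlier examples; only the tokenizer runs, no pipeline components are applied):

>>> doc = nlp.make_doc("Das ist ein Satz.")
>>> len(doc)
5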
+
+def memory_zone(self, mem: cymem.cymem.Pool | None = None) ‑> Iterator[cymem.cymem.Pool] +
+
+
+ +Expand source code + +
@contextmanager
+def memory_zone(self, mem: Optional[Pool] = None) -> Iterator[Pool]:
+    """Begin a block where all resources allocated during the block will
+    be freed at the end of it. If a resources was created within the
+    memory zone block, accessing it outside the block is invalid.
+    Behaviour of this invalid access is undefined. Memory zones should
+    not be nested.
+
+    The memory zone is helpful for services that need to process large
+    volumes of text with a defined memory budget.
+
+    Example
+    -------
+    >>> with nlp.memory_zone():
+    ...     for doc in nlp.pipe(texts):
+    ...        process_my_doc(doc)
+    >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
+    """
+    if mem is None:
+        mem = Pool()
+    # The ExitStack allows programmatic nested context managers.
+    # We don't know how many we need, so it would be awkward to have
+    # them as nested blocks.
+    with ExitStack() as stack:
+        contexts = [stack.enter_context(self.vocab.memory_zone(mem))]
+        if hasattr(self.tokenizer, "memory_zone"):
+            contexts.append(stack.enter_context(self.tokenizer.memory_zone(mem)))
+        for _, pipe in self.pipeline:
+            if hasattr(pipe, "memory_zone"):
+                contexts.append(stack.enter_context(pipe.memory_zone(mem)))
+        yield mem
+
+

Begin a block where all resources allocated during the block will be freed at the end of it. If a resource was created within the memory zone block, accessing it outside the block is invalid. Behaviour of this invalid access is undefined. Memory zones should not be nested.

+

The memory zone is helpful for services that need to process large volumes of text with a defined memory budget.

+

Example

+
>>> with nlp.memory_zone():
+...     for doc in nlp.pipe(texts):
+...        process_my_doc(doc)
+>>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
+
+
+
+def pipe(self,
texts: Iterable[str | spacy.tokens.doc.Doc] | Iterable[Tuple[str | spacy.tokens.doc.Doc, ~_AnyContext]],
*,
as_tuples: bool = False,
batch_size: int | None = None,
disable: Iterable[str] = [],
component_cfg: Dict[str, Dict[str, Any]] | None = None,
n_process: int = 1) ‑> Iterator[spacy.tokens.doc.Doc] | Iterator[Tuple[spacy.tokens.doc.Doc, ~_AnyContext]]
+
+
+
+ +Expand source code + +
def pipe(  # noqa: F811
+    self,
+    texts: Union[
+        Iterable[Union[str, Doc]], Iterable[Tuple[Union[str, Doc], _AnyContext]]
+    ],
+    *,
+    as_tuples: bool = False,
+    batch_size: Optional[int] = None,
+    disable: Iterable[str] = SimpleFrozenList(),
+    component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+    n_process: int = 1,
+) -> Union[Iterator[Doc], Iterator[Tuple[Doc, _AnyContext]]]:
+    """Process texts as a stream, and yield `Doc` objects in order.
+
+    texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to
+        process.
+    as_tuples (bool): If set to True, inputs should be a sequence of
+        (text, context) tuples. Output will then be a sequence of
+        (doc, context) tuples. Defaults to False.
+    batch_size (Optional[int]): The number of texts to buffer.
+    disable (List[str]): Names of the pipeline components to disable.
+    component_cfg (Dict[str, Dict]): An optional dictionary with extra keyword
+        arguments for specific components.
+    n_process (int): Number of processors to process texts. If -1, set `multiprocessing.cpu_count()`.
+    YIELDS (Doc): Documents in the order of the original text.
+
+    DOCS: https://spacy.io/api/language#pipe
+    """
+    if as_tuples:
+        texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
+        docs_with_contexts = (
+            self._ensure_doc_with_context(text, context) for text, context in texts
+        )
+        docs = self.pipe(
+            docs_with_contexts,
+            batch_size=batch_size,
+            disable=disable,
+            n_process=n_process,
+            component_cfg=component_cfg,
+        )
+        for doc in docs:
+            context = doc._context
+            doc._context = None
+            yield (doc, context)
+        return
+
+    texts = cast(Iterable[Union[str, Doc]], texts)
+
+    # Set argument defaults
+    if n_process == -1:
+        n_process = mp.cpu_count()
+    if component_cfg is None:
+        component_cfg = {}
+    if batch_size is None:
+        batch_size = self.batch_size
+
+    pipes = (
+        []
+    )  # contains functools.partial objects to easily create multiprocess worker.
+    for name, proc in self.pipeline:
+        if name in disable:
+            continue
+        kwargs = component_cfg.get(name, {})
+        # Allow component_cfg to overwrite the top-level kwargs.
+        kwargs.setdefault("batch_size", batch_size)
+        f = functools.partial(
+            _pipe,
+            proc=proc,
+            name=name,
+            kwargs=kwargs,
+            default_error_handler=self.default_error_handler,
+        )
+        pipes.append(f)
+
+    if n_process != 1:
+        if self._has_gpu_model(disable):
+            warnings.warn(Warnings.W114)
+
+        docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size)
+    else:
+        # if n_process == 1, no processes are forked.
+        docs = (self._ensure_doc(text) for text in texts)
+        for pipe in pipes:
+            docs = pipe(docs)
+    for doc in docs:
+        yield doc
+
+

Process texts as a stream, and yield Doc objects in order.

+

texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to process.
as_tuples (bool): If set to True, inputs should be a sequence of (text, context) tuples. Output will then be a sequence of (doc, context) tuples. Defaults to False.
batch_size (Optional[int]): The number of texts to buffer.
disable (List[str]): Names of the pipeline components to disable.
component_cfg (Dict[str, Dict]): An optional dictionary with extra keyword arguments for specific components.
n_process (int): Number of processes to use for processing texts. If -1, multiprocessing.cpu_count() is used.
YIELDS (Doc): Documents in the order of the original text.

+

DOCS: https://spacy.io/api/language#pipe

+
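A streaming sketch (the texts, batch size and context values are illustrative):

>>> texts = ["Erster Text.", "Zweiter Text."]
>>> for doc in nlp.pipe(texts, batch_size=50):
...     pass  # process each doc as it is yielded
>>> for doc, meta in nlp.pipe([("Ein Text.", {"id": 1})], as_tuples=True):
...     print(meta["id"])
1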
+
+def rehearse(self,
examples: Iterable[spacy.training.example.Example],
*,
sgd: thinc.optimizers.Optimizer | None = None,
losses: Dict[str, float] | None = None,
component_cfg: Dict[str, Dict[str, Any]] | None = None,
exclude: Iterable[str] = []) ‑> Dict[str, float]
+
+
+
+ +Expand source code + +
def rehearse(
+    self,
+    examples: Iterable[Example],
+    *,
+    sgd: Optional[Optimizer] = None,
+    losses: Optional[Dict[str, float]] = None,
+    component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+    exclude: Iterable[str] = SimpleFrozenList(),
+) -> Dict[str, float]:
+    """Make a "rehearsal" update to the models in the pipeline, to prevent
+    forgetting. Rehearsal updates run an initial copy of the model over some
+    data, and update the model so its current predictions are more like the
+    initial ones. This is useful for keeping a pretrained model on-track,
+    even if you're updating it with a smaller set of examples.
+
+    examples (Iterable[Example]): A batch of `Example` objects.
+    sgd (Optional[Optimizer]): An optimizer.
+    component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
+        components, keyed by component name.
+    exclude (Iterable[str]): Names of components that shouldn't be updated.
+    RETURNS (dict): Results from the update.
+
+    EXAMPLE:
+        >>> raw_text_batches = minibatch(raw_texts)
+        >>> for labelled_batch in minibatch(examples):
+        >>>     nlp.update(labelled_batch)
+        >>>     raw_batch = [Example.from_dict(nlp.make_doc(text), {}) for text in next(raw_text_batches)]
+        >>>     nlp.rehearse(raw_batch)
+
+    DOCS: https://spacy.io/api/language#rehearse
+    """
+    if losses is None:
+        losses = {}
+    if isinstance(examples, list) and len(examples) == 0:
+        return losses
+    validate_examples(examples, "Language.rehearse")
+    if sgd is None:
+        if self._optimizer is None:
+            self._optimizer = self.create_optimizer()
+        sgd = self._optimizer
+    pipes = list(self.pipeline)
+    random.shuffle(pipes)
+    if component_cfg is None:
+        component_cfg = {}
+    grads = {}
+
+    def get_grads(key, W, dW):
+        grads[key] = (W, dW)
+        return W, dW
+
+    get_grads.learn_rate = sgd.learn_rate  # type: ignore[attr-defined, union-attr]
+    get_grads.b1 = sgd.b1  # type: ignore[attr-defined, union-attr]
+    get_grads.b2 = sgd.b2  # type: ignore[attr-defined, union-attr]
+    for name, proc in pipes:
+        if name in exclude or not hasattr(proc, "rehearse"):
+            continue
+        grads = {}
+        proc.rehearse(  # type: ignore[attr-defined]
+            examples, sgd=get_grads, losses=losses, **component_cfg.get(name, {})
+        )
+    for key, (W, dW) in grads.items():
+        sgd(key, W, dW)  # type: ignore[call-arg, misc]
+    return losses
+
+

Make a "rehearsal" update to the models in the pipeline, to prevent forgetting. Rehearsal updates run an initial copy of the model over some data, and update the model so its current predictions are more like the initial ones. This is useful for keeping a pretrained model on-track, even if you're updating it with a smaller set of examples.

+

examples (Iterable[Example]): A batch of Example objects. +sgd (Optional[Optimizer]): An optimizer. +component_cfg (Dict[str, Dict]): Config parameters for specific pipeline +components, keyed by component name. +exclude (Iterable[str]): Names of components that shouldn't be updated. +RETURNS (dict): Results from the update.

+

Example

+
>>> raw_text_batches = minibatch(raw_texts)
+>>> for labelled_batch in minibatch(examples):
+>>>     nlp.update(labelled_batch)
+>>>     raw_batch = [Example.from_dict(nlp.make_doc(text), {}) for text in next(raw_text_batches)]
+>>>     nlp.rehearse(raw_batch)
+
+

DOCS: https://spacy.io/api/language#rehearse

+
+
+def remove_pipe(self, name: str) ‑> Tuple[str, Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc]] +
+
+
+ +Expand source code + +
def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
+    """Remove a component from the pipeline.
+
+    name (str): Name of the component to remove.
+    RETURNS (Tuple[str, Callable[[Doc], Doc]]): A `(name, component)` tuple of the removed component.
+
+    DOCS: https://spacy.io/api/language#remove_pipe
+    """
+    if name not in self.component_names:
+        raise ValueError(Errors.E001.format(name=name, opts=self.component_names))
+    removed = self._components.pop(self.component_names.index(name))
+    # We're only removing the component itself from the metas/configs here
+    # because factory may be used for something else
+    self._pipe_meta.pop(name)
+    self._pipe_configs.pop(name)
+    self.meta.get("_sourced_vectors_hashes", {}).pop(name, None)
+    # Make sure name is removed from the [initialize] config
+    if name in self._config["initialize"]["components"]:
+        self._config["initialize"]["components"].pop(name)
+    # Make sure the name is also removed from the set of disabled components
+    if name in self.disabled:
+        self._disabled.remove(name)
+    self._link_components()
+    return removed
+
+

Remove a component from the pipeline.

+

name (str): Name of the component to remove.
RETURNS (Tuple[str, Callable[[Doc], Doc]]): A (name, component) tuple of the removed component.

+

DOCS: https://spacy.io/api/language#remove_pipe
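A minimal usage sketch; the blank English pipeline and the "ner" component are assumed here purely for illustration:

>>> import spacy
>>> nlp = spacy.blank("en")
>>> nlp.add_pipe("ner")
>>> name, component = nlp.remove_pipe("ner")   # (name, component) tuple of the removed pipe
>>> assert "ner" not in nlp.component_names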

+
+
+def rename_pipe(self, old_name: str, new_name: str) ‑> None +
+
+
+ +Expand source code + +
def rename_pipe(self, old_name: str, new_name: str) -> None:
+    """Rename a pipeline component.
+
+    old_name (str): Name of the component to rename.
+    new_name (str): New name of the component.
+
+    DOCS: https://spacy.io/api/language#rename_pipe
+    """
+    if old_name not in self.component_names:
+        raise ValueError(
+            Errors.E001.format(name=old_name, opts=self.component_names)
+        )
+    if new_name in self.component_names:
+        raise ValueError(
+            Errors.E007.format(name=new_name, opts=self.component_names)
+        )
+    i = self.component_names.index(old_name)
+    self._components[i] = (new_name, self._components[i][1])
+    self._pipe_meta[new_name] = self._pipe_meta.pop(old_name)
+    self._pipe_configs[new_name] = self._pipe_configs.pop(old_name)
+    # Make sure [initialize] config is adjusted
+    if old_name in self._config["initialize"]["components"]:
+        init_cfg = self._config["initialize"]["components"].pop(old_name)
+        self._config["initialize"]["components"][new_name] = init_cfg
+    self._link_components()
+
+

Rename a pipeline component.

+

old_name (str): Name of the component to rename.
new_name (str): New name of the component.

+

DOCS: https://spacy.io/api/language#rename_pipe
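A short sketch, again with an assumed blank pipeline and component name:

>>> import spacy
>>> nlp = spacy.blank("en")
>>> nlp.add_pipe("ner")
>>> nlp.rename_pipe("ner", "my_ner")
>>> assert nlp.pipe_names == ["my_ner"]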

+
+
+def replace_listeners(self, tok2vec_name: str, pipe_name: str, listeners: Iterable[str]) ‑> None +
+
+
+ +Expand source code + +
def replace_listeners(
+    self,
+    tok2vec_name: str,
+    pipe_name: str,
+    listeners: Iterable[str],
+) -> None:
+    """Find listener layers (connecting to a token-to-vector embedding
+    component) of a given pipeline component model and replace
+    them with a standalone copy of the token-to-vector layer. This can be
+    useful when training a pipeline with components sourced from an existing
+    pipeline: if multiple components (e.g. tagger, parser, NER) listen to
+    the same tok2vec component, but some of them are frozen and not updated,
+    their performance may degrade significantly as the tok2vec component is
+    updated with new data. To prevent this, listeners can be replaced with
+    a standalone tok2vec layer that is owned by the component and doesn't
+    change if the component isn't updated.
+
+    tok2vec_name (str): Name of the token-to-vector component, typically
+        "tok2vec" or "transformer".
+    pipe_name (str): Name of pipeline component to replace listeners for.
+    listeners (Iterable[str]): The paths to the listeners, relative to the
+        component config, e.g. ["model.tok2vec"]. Typically, implementations
+        will only connect to one tok2vec component, [model.tok2vec], but in
+        theory, custom models can use multiple listeners. The value here can
+        either be an empty list to not replace any listeners, or a complete
+        (!) list of the paths to all listener layers used by the model.
+
+    DOCS: https://spacy.io/api/language#replace_listeners
+    """
+    if tok2vec_name not in self.pipe_names:
+        err = Errors.E889.format(
+            tok2vec=tok2vec_name,
+            name=pipe_name,
+            unknown=tok2vec_name,
+            opts=", ".join(self.pipe_names),
+        )
+        raise ValueError(err)
+    if pipe_name not in self.pipe_names:
+        err = Errors.E889.format(
+            tok2vec=tok2vec_name,
+            name=pipe_name,
+            unknown=pipe_name,
+            opts=", ".join(self.pipe_names),
+        )
+        raise ValueError(err)
+    tok2vec = self.get_pipe(tok2vec_name)
+    tok2vec_cfg = self.get_pipe_config(tok2vec_name)
+    if not isinstance(tok2vec, ty.ListenedToComponent):
+        raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
+    tok2vec_model = tok2vec.model
+    pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
+    pipe = self.get_pipe(pipe_name)
+    pipe_cfg = self._pipe_configs[pipe_name]
+    if listeners:
+        util.logger.debug("Replacing listeners of component '%s'", pipe_name)
+        if len(list(listeners)) != len(pipe_listeners):
+            # The number of listeners defined in the component model doesn't
+            # match the listeners to replace, so we won't be able to update
+            # the nodes and generate a matching config
+            err = Errors.E887.format(
+                name=pipe_name,
+                tok2vec=tok2vec_name,
+                paths=listeners,
+                n_listeners=len(pipe_listeners),
+            )
+            raise ValueError(err)
+        # Update the config accordingly by copying the tok2vec model to all
+        # sections defined in the listener paths
+        for listener_path in listeners:
+            # Check if the path actually exists in the config
+            try:
+                util.dot_to_object(pipe_cfg, listener_path)
+            except KeyError:
+                err = Errors.E886.format(
+                    name=pipe_name, tok2vec=tok2vec_name, path=listener_path
+                )
+                raise ValueError(err)
+            new_config = tok2vec_cfg["model"]
+            if "replace_listener_cfg" in tok2vec_model.attrs:
+                replace_func = tok2vec_model.attrs["replace_listener_cfg"]
+                new_config = replace_func(
+                    tok2vec_cfg["model"], pipe_cfg["model"]["tok2vec"]
+                )
+            util.set_dot_to_object(pipe_cfg, listener_path, new_config)
+        # Go over the listener layers and replace them
+        for listener in pipe_listeners:
+            new_model = tok2vec_model.copy()
+            replace_listener_func = tok2vec_model.attrs.get("replace_listener")
+            if replace_listener_func is not None:
+                # Pass the extra args to the callback without breaking compatibility with
+                # old library versions that only expect a single parameter.
+                num_params = len(
+                    inspect.signature(replace_listener_func).parameters
+                )
+                if num_params == 1:
+                    new_model = replace_listener_func(new_model)
+                elif num_params == 3:
+                    new_model = replace_listener_func(new_model, listener, tok2vec)
+                else:
+                    raise ValueError(Errors.E1055.format(num_params=num_params))
+
+            util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
+            tok2vec.remove_listener(listener, pipe_name)
+
+

Find listener layers (connecting to a token-to-vector embedding component) of a given pipeline component model and replace them with a standalone copy of the token-to-vector layer. This can be useful when training a pipeline with components sourced from an existing pipeline: if multiple components (e.g. tagger, parser, NER) listen to the same tok2vec component, but some of them are frozen and not updated, their performance may degrade significantly as the tok2vec component is updated with new data. To prevent this, listeners can be replaced with a standalone tok2vec layer that is owned by the component and doesn't change if the component isn't updated.

+

tok2vec_name (str): Name of the token-to-vector component, typically "tok2vec" or "transformer".
pipe_name (str): Name of pipeline component to replace listeners for.
listeners (Iterable[str]): The paths to the listeners, relative to the component config, e.g. ["model.tok2vec"]. Typically, implementations will only connect to one tok2vec component, [model.tok2vec], but in theory, custom models can use multiple listeners. The value here can either be an empty list to not replace any listeners, or a complete (!) list of the paths to all listener layers used by the model.

+

DOCS: https://spacy.io/api/language#replace_listeners
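A sketch of the typical call, assuming an installed pretrained package (here "en_core_web_sm") in which the tagger listens to a shared "tok2vec" component via the "model.tok2vec" path:

>>> import spacy
>>> nlp = spacy.load("en_core_web_sm")
>>> nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
>>> # the tagger now owns a standalone copy of the tok2vec layer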

+
+
+def replace_pipe(self,
name: str,
factory_name: str,
*,
config: Dict[str, Any] = {},
validate: bool = True) ‑> Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc]
+
+
+
+ +Expand source code + +
def replace_pipe(
+    self,
+    name: str,
+    factory_name: str,
+    *,
+    config: Dict[str, Any] = SimpleFrozenDict(),
+    validate: bool = True,
+) -> PipeCallable:
+    """Replace a component in the pipeline.
+
+    name (str): Name of the component to replace.
+    factory_name (str): Factory name of replacement component.
+    config (Optional[Dict[str, Any]]): Config parameters to use for this
+        component. Will be merged with default config, if available.
+    validate (bool): Whether to validate the component config against the
+        arguments and types expected by the factory.
+    RETURNS (Callable[[Doc], Doc]): The new pipeline component.
+
+    DOCS: https://spacy.io/api/language#replace_pipe
+    """
+    if name not in self.component_names:
+        raise ValueError(Errors.E001.format(name=name, opts=self.pipe_names))
+    if hasattr(factory_name, "__call__"):
+        err = Errors.E968.format(component=repr(factory_name), name=name)
+        raise ValueError(err)
+    # We need to delegate to Language.add_pipe here instead of just writing
+    # to Language.pipeline to make sure the configs are handled correctly
+    pipe_index = self.component_names.index(name)
+    self.remove_pipe(name)
+    if not len(self._components) or pipe_index == len(self._components):
+        # we have no components to insert before/after, or we're replacing the last component
+        return self.add_pipe(
+            factory_name, name=name, config=config, validate=validate
+        )
+    else:
+        return self.add_pipe(
+            factory_name,
+            name=name,
+            before=pipe_index,
+            config=config,
+            validate=validate,
+        )
+
+

Replace a component in the pipeline.

+

name (str): Name of the component to replace.
factory_name (str): Factory name of replacement component.
config (Optional[Dict[str, Any]]): Config parameters to use for this component. Will be merged with default config, if available.
validate (bool): Whether to validate the component config against the arguments and types expected by the factory.
RETURNS (Callable[[Doc], Doc]): The new pipeline component.

+

DOCS: https://spacy.io/api/language#replace_pipe
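A minimal sketch; the blank pipeline and the built-in "sentencizer" factory are assumed for illustration, and the component is swapped in place with a new config:

>>> import spacy
>>> nlp = spacy.blank("en")
>>> nlp.add_pipe("sentencizer")
>>> new_sentencizer = nlp.replace_pipe("sentencizer", "sentencizer", config={"punct_chars": ["."]})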

+
+
+def resume_training(self, *, sgd: thinc.optimizers.Optimizer | None = None) ‑> thinc.optimizers.Optimizer +
+
+
+ +Expand source code + +
def resume_training(self, *, sgd: Optional[Optimizer] = None) -> Optimizer:
+    """Continue training a pretrained model.
+
+    Create and return an optimizer, and initialize "rehearsal" for any pipeline
+    component that has a .rehearse() method. Rehearsal is used to prevent
+    models from "forgetting" their initialized "knowledge". To perform
+    rehearsal, collect samples of text you want the models to retain performance
+    on, and call nlp.rehearse() with a batch of Example objects.
+
+    RETURNS (Optimizer): The optimizer.
+
+    DOCS: https://spacy.io/api/language#resume_training
+    """
+    ops = get_current_ops()
+    if self.vocab.vectors.shape[1] >= 1:
+        self.vocab.vectors.to_ops(ops)
+    for name, proc in self.pipeline:
+        if hasattr(proc, "_rehearsal_model"):
+            proc._rehearsal_model = deepcopy(proc.model)  # type: ignore[attr-defined]
+    if sgd is not None:
+        self._optimizer = sgd
+    elif self._optimizer is None:
+        self._optimizer = self.create_optimizer()
+    return self._optimizer
+
+

Continue training a pretrained model.

+

Create and return an optimizer, and initialize "rehearsal" for any pipeline component that has a .rehearse() method. Rehearsal is used to prevent models from "forgetting" their initialized "knowledge". To perform rehearsal, collect samples of text you want the models to retain performance on, and call nlp.rehearse() with a batch of Example objects.

+

RETURNS (Optimizer): The optimizer.

+

DOCS: https://spacy.io/api/language#resume_training
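A sketch of how this pairs with rehearse(), assuming an installed pretrained pipeline:

>>> import spacy
>>> nlp = spacy.load("en_core_web_sm")
>>> optimizer = nlp.resume_training()
>>> # later, inside the training loop:
>>> # nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)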

+
+
+def select_pipes(self,
*,
disable: str | Iterable[str] | None = None,
enable: str | Iterable[str] | None = None) ‑> spacy.language.DisabledPipes
+
+
+
+ +Expand source code + +
def select_pipes(
+    self,
+    *,
+    disable: Optional[Union[str, Iterable[str]]] = None,
+    enable: Optional[Union[str, Iterable[str]]] = None,
+) -> "DisabledPipes":
+    """Disable one or more pipeline components. If used as a context
+    manager, the pipeline will be restored to the initial state at the end
+    of the block. Otherwise, a DisabledPipes object is returned, that has
+    a `.restore()` method you can use to undo your changes.
+
+    disable (str or iterable): The name(s) of the pipes to disable
+    enable (str or iterable): The name(s) of the pipes to enable - all others will be disabled
+
+    DOCS: https://spacy.io/api/language#select_pipes
+    """
+    if enable is None and disable is None:
+        raise ValueError(Errors.E991)
+    if isinstance(disable, str):
+        disable = [disable]
+    if enable is not None:
+        if isinstance(enable, str):
+            enable = [enable]
+        to_disable = [pipe for pipe in self.pipe_names if pipe not in enable]
+        # raise an error if the enable and disable keywords are not consistent
+        if disable is not None and disable != to_disable:
+            raise ValueError(
+                Errors.E992.format(
+                    enable=enable, disable=disable, names=self.pipe_names
+                )
+            )
+        disable = to_disable
+    assert disable is not None
+    # DisabledPipes will restore the pipes in 'disable' when it's done, so we need to exclude
+    # those pipes that were already disabled.
+    disable = [d for d in disable if d not in self._disabled]
+    return DisabledPipes(self, disable)
+
+

Disable one or more pipeline components. If used as a context manager, the pipeline will be restored to the initial state at the end of the block. Otherwise, a DisabledPipes object is returned, that has a .restore() method you can use to undo your changes.

+

disable (str or iterable): The name(s) of the pipes to disable
enable (str or iterable): The name(s) of the pipes to enable - all others will be disabled

+

DOCS: https://spacy.io/api/language#select_pipes
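A context-manager sketch; the pipeline and component names are assumed for illustration:

>>> import spacy
>>> nlp = spacy.load("en_core_web_sm")
>>> with nlp.select_pipes(disable=["parser", "ner"]):
>>>     doc = nlp("Disabled components are skipped inside this block.")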

+
+
+def set_error_handler(self,
error_handler: Callable[[str, Callable[[spacy.tokens.doc.Doc], spacy.tokens.doc.Doc], List[spacy.tokens.doc.Doc], Exception], NoReturn])
+
+
+
+ +Expand source code + +
def set_error_handler(
+    self,
+    error_handler: Callable[[str, PipeCallable, List[Doc], Exception], NoReturn],
+):
+    """Set an error handler object for all the components in the pipeline
+    that implement a set_error_handler function.
+
+    error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], NoReturn]):
+        Function that deals with a failing batch of documents. This callable
+        function should take in the component's name, the component itself,
+        the offending batch of documents, and the exception that was thrown.
+    DOCS: https://spacy.io/api/language#set_error_handler
+    """
+    self.default_error_handler = error_handler
+    for name, pipe in self.pipeline:
+        if hasattr(pipe, "set_error_handler"):
+            pipe.set_error_handler(error_handler)
+
+

Set an error handler object for all the components in the pipeline +that implement a set_error_handler function.

+

error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], NoReturn]): Function that deals with a failing batch of documents. This callable function should take in the component's name, the component itself, the offending batch of documents, and the exception that was thrown.
DOCS: https://spacy.io/api/language#set_error_handler
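A minimal sketch of a custom handler; the warn-and-continue function below is an illustrative assumption, not a built-in:

>>> import spacy, warnings
>>> nlp = spacy.blank("en")
>>> def warn_and_continue(name, component, docs, exc):
>>>     warnings.warn(f"component '{name}' failed on {len(docs)} doc(s): {exc}")
>>> nlp.set_error_handler(warn_and_continue)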

+
+
+def to_bytes(self, *, exclude: Iterable[str] = []) ‑> bytes +
+
+
+ +Expand source code + +
def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
+    """Serialize the current state to a binary string.
+
+    exclude (Iterable[str]): Names of components or serialization fields to exclude.
+    RETURNS (bytes): The serialized form of the `Language` object.
+
+    DOCS: https://spacy.io/api/language#to_bytes
+    """
+    serializers: Dict[str, Callable[[], bytes]] = {}
+    serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
+    serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])  # type: ignore[union-attr]
+    serializers["meta.json"] = lambda: srsly.json_dumps(
+        _replace_numpy_floats(self.meta)
+    )
+    serializers["config.cfg"] = lambda: self.config.to_bytes()
+    for name, proc in self._components:
+        if name in exclude:
+            continue
+        if not hasattr(proc, "to_bytes"):
+            continue
+        serializers[name] = lambda proc=proc: proc.to_bytes(exclude=["vocab"])  # type: ignore[misc]
+    return util.to_bytes(serializers, exclude)
+
+

Serialize the current state to a binary string.

+

exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (bytes): The serialized form of the Language object.

+

DOCS: https://spacy.io/api/language#to_bytes
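A round-trip sketch with a blank pipeline (from_bytes is the documented counterpart):

>>> import spacy
>>> nlp = spacy.blank("en")
>>> data = nlp.to_bytes()
>>> nlp2 = spacy.blank("en").from_bytes(data)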

+
+
+def to_disk(self, path: str | pathlib.Path, *, exclude: Iterable[str] = []) ‑> None +
+
+
+ +Expand source code + +
def to_disk(
+    self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
+) -> None:
+    """Save the current state to a directory.  If a model is loaded, this
+    will include the model.
+
+    path (str / Path): Path to a directory, which will be created if
+        it doesn't exist.
+    exclude (Iterable[str]): Names of components or serialization fields to exclude.
+
+    DOCS: https://spacy.io/api/language#to_disk
+    """
+    path = util.ensure_path(path)
+    serializers = {}
+    serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(  # type: ignore[union-attr]
+        p, exclude=["vocab"]
+    )
+    serializers["meta.json"] = lambda p: srsly.write_json(
+        p, _replace_numpy_floats(self.meta)
+    )
+    serializers["config.cfg"] = lambda p: self.config.to_disk(p)
+    for name, proc in self._components:
+        if name in exclude:
+            continue
+        if not hasattr(proc, "to_disk"):
+            continue
+        serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"])  # type: ignore[misc]
+    serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
+    util.to_disk(path, serializers, exclude)
+
+

Save the current state to a directory. If a model is loaded, this will include the model.

+

path (str / Path): Path to a directory, which will be created if it doesn't exist.
exclude (Iterable[str]): Names of components or serialization fields to exclude.

+

DOCS: https://spacy.io/api/language#to_disk
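A round-trip sketch; the target directory is an illustrative assumption:

>>> import spacy
>>> nlp = spacy.blank("en")
>>> nlp.to_disk("/tmp/my_pipeline")   # created if it doesn't exist
>>> nlp2 = spacy.load("/tmp/my_pipeline")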

+
+
+def update(self,
examples: Iterable[spacy.training.example.Example],
*,
drop: float = 0.0,
sgd: thinc.optimizers.Optimizer | None = None,
losses: Dict[str, float] | None = None,
component_cfg: Dict[str, Dict[str, Any]] | None = None,
exclude: Iterable[str] = [],
annotates: Iterable[str] = [])
+
+
+
+ +Expand source code + +
def update(
+    self,
+    examples: Iterable[Example],
+    _: Optional[Any] = None,
+    *,
+    drop: float = 0.0,
+    sgd: Optional[Optimizer] = None,
+    losses: Optional[Dict[str, float]] = None,
+    component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+    exclude: Iterable[str] = SimpleFrozenList(),
+    annotates: Iterable[str] = SimpleFrozenList(),
+):
+    """Update the models in the pipeline.
+
+    examples (Iterable[Example]): A batch of examples
+    _: Should not be set - serves to catch backwards-incompatible scripts.
+    drop (float): The dropout rate.
+    sgd (Optimizer): An optimizer.
+    losses (Dict[str, float]): Dictionary to update with the loss, keyed by
+        component.
+    component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
+        components, keyed by component name.
+    exclude (Iterable[str]): Names of components that shouldn't be updated.
+    annotates (Iterable[str]): Names of components that should set
+        annotations on the predicted examples after updating.
+    RETURNS (Dict[str, float]): The updated losses dictionary
+
+    DOCS: https://spacy.io/api/language#update
+    """
+    if _ is not None:
+        raise ValueError(Errors.E989)
+    if losses is None:
+        losses = {}
+    if isinstance(examples, list) and len(examples) == 0:
+        return losses
+    validate_examples(examples, "Language.update")
+    examples = _copy_examples(examples)
+    if sgd is None:
+        if self._optimizer is None:
+            self._optimizer = self.create_optimizer()
+        sgd = self._optimizer
+    if component_cfg is None:
+        component_cfg = {}
+    pipe_kwargs = {}
+    for i, (name, proc) in enumerate(self.pipeline):
+        component_cfg.setdefault(name, {})
+        pipe_kwargs[name] = deepcopy(component_cfg[name])
+        component_cfg[name].setdefault("drop", drop)
+        pipe_kwargs[name].setdefault("batch_size", self.batch_size)
+    for name, proc in self.pipeline:
+        # ignore statements are used here because mypy ignores hasattr
+        if name not in exclude and hasattr(proc, "update"):
+            proc.update(examples, sgd=None, losses=losses, **component_cfg[name])  # type: ignore
+        if sgd not in (None, False):
+            if (
+                name not in exclude
+                and isinstance(proc, ty.TrainableComponent)
+                and proc.is_trainable
+                and proc.model not in (True, False, None)
+            ):
+                proc.finish_update(sgd)
+        if name in annotates:
+            for doc, eg in zip(
+                _pipe(
+                    (eg.predicted for eg in examples),
+                    proc=proc,
+                    name=name,
+                    default_error_handler=self.default_error_handler,
+                    kwargs=pipe_kwargs[name],
+                ),
+                examples,
+            ):
+                eg.predicted = doc
+    return _replace_numpy_floats(losses)
+
+

Update the models in the pipeline.

+

examples (Iterable[Example]): A batch of examples
_: Should not be set - serves to catch backwards-incompatible scripts.
drop (float): The dropout rate.
sgd (Optimizer): An optimizer.
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
annotates (Iterable[str]): Names of components that should set annotations on the predicted examples after updating.
RETURNS (Dict[str, float]): The updated losses dictionary

+

DOCS: https://spacy.io/api/language#update
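A minimal training-step sketch; the blank pipeline, the single NER example, and the "CITY" label are illustrative assumptions:

>>> import spacy
>>> from spacy.training import Example
>>> nlp = spacy.blank("en")
>>> nlp.add_pipe("ner")
>>> example = Example.from_dict(nlp.make_doc("I live in Berlin"), {"entities": [(10, 16, "CITY")]})
>>> optimizer = nlp.initialize(lambda: [example])
>>> losses = nlp.update([example], sgd=optimizer)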

+
+
+def use_params(self, params: dict | None) +
+
+
+ +Expand source code + +
@contextmanager
+def use_params(self, params: Optional[dict]):
+    """Replace weights of models in the pipeline with those provided in the
+    params dictionary. Can be used as a contextmanager, in which case,
+    models go back to their original weights after the block.
+
+    params (dict): A dictionary of parameters keyed by model ID.
+
+    EXAMPLE:
+        >>> with nlp.use_params(optimizer.averages):
+        >>>     nlp.to_disk("/tmp/checkpoint")
+
+    DOCS: https://spacy.io/api/language#use_params
+    """
+    if not params:
+        yield
+    else:
+        contexts = [
+            pipe.use_params(params)  # type: ignore[attr-defined]
+            for name, pipe in self.pipeline
+            if hasattr(pipe, "use_params") and hasattr(pipe, "model")
+        ]
+        # TODO: Having trouble with contextlib
+        # Workaround: these aren't actually context managers atm.
+        for context in contexts:
+            try:
+                next(context)
+            except StopIteration:
+                pass
+        yield
+        for context in contexts:
+            try:
+                next(context)
+            except StopIteration:
+                pass
+
+

Replace weights of models in the pipeline with those provided in the params dictionary. Can be used as a contextmanager, in which case, models go back to their original weights after the block.

+

params (dict): A dictionary of parameters keyed by model ID.

+

Example

+
>>> with nlp.use_params(optimizer.averages):
+>>>     nlp.to_disk("/tmp/checkpoint")
+
+

DOCS: https://spacy.io/api/language#use_params

+
+
+
+
+class SentenceTransformer +(model_name_or_path: str | None = None,
modules: Iterable[nn.Module] | None = None,
device: str | None = None,
prompts: dict[str, str] | None = None,
default_prompt_name: str | None = None,
similarity_fn_name: str | SimilarityFunction | None = None,
cache_folder: str | None = None,
trust_remote_code: bool = False,
revision: str | None = None,
local_files_only: bool = False,
token: bool | str | None = None,
use_auth_token: bool | str | None = None,
truncate_dim: int | None = None,
model_kwargs: dict[str, Any] | None = None,
tokenizer_kwargs: dict[str, Any] | None = None,
config_kwargs: dict[str, Any] | None = None,
model_card_data: SentenceTransformerModelCardData | None = None,
backend: "Literal['torch', 'onnx', 'openvino']" = 'torch')
+
+
+
+ +Expand source code + +
class SentenceTransformer(nn.Sequential, FitMixin, PeftAdapterMixin):
+    """
+    Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.
+
+    Args:
+        model_name_or_path (str, optional): If it is a filepath on disc, it loads the model from that path. If it is not a path,
+            it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model
+            from the Hugging Face Hub with that name.
+        modules (Iterable[nn.Module], optional): A list of torch Modules that should be called sequentially, can be used to create custom
+            SentenceTransformer models from scratch.
+        device (str, optional): Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU
+            can be used.
+        prompts (Dict[str, str], optional): A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text.
+            The prompt text will be prepended before any text to encode. For example:
+            `{"query": "query: ", "passage": "passage: "}` or `{"clustering": "Identify the main category based on the
+            titles in "}`.
+        default_prompt_name (str, optional): The name of the prompt that should be used by default. If not set,
+            no prompt will be applied.
+        similarity_fn_name (str or SimilarityFunction, optional): The name of the similarity function to use. Valid options are "cosine", "dot",
+            "euclidean", and "manhattan". If not set, it is automatically set to "cosine" if `similarity` or
+            `similarity_pairwise` are called while `model.similarity_fn_name` is still `None`.
+        cache_folder (str, optional): Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
+        trust_remote_code (bool, optional): Whether or not to allow for custom models defined on the Hub in their own modeling files.
+            This option should only be set to True for repositories you trust and in which you have read the code, as it
+            will execute code present on the Hub on your local machine.
+        revision (str, optional): The specific model version to use. It can be a branch name, a tag name, or a commit id,
+            for a stored model on Hugging Face.
+        local_files_only (bool, optional): Whether or not to only look at local files (i.e., do not try to download the model).
+        token (bool or str, optional): Hugging Face authentication token to download private models.
+        use_auth_token (bool or str, optional): Deprecated argument. Please use `token` instead.
+        truncate_dim (int, optional): The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
+            only applicable during inference when :meth:`SentenceTransformer.encode` is called.
+        model_kwargs (Dict[str, Any], optional): Additional model configuration parameters to be passed to the Hugging Face Transformers model.
+            Particularly useful options are:
+
+            - ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`.
+              The different options are:
+
+                    1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in a specified
+                    ``dtype``, ignoring the model's ``config.torch_dtype`` if one exists. If not specified - the model will
+                    get loaded in ``torch.float`` (fp32).
+
+                    2. ``"auto"`` - A ``torch_dtype`` entry in the ``config.json`` file of the model will be
+                    attempted to be used. If this entry isn't found then next check the ``dtype`` of the first weight in
+                    the checkpoint that's of a floating point type and use that as ``dtype``. This will load the model
+                    using the ``dtype`` it was saved in at the end of the training. It can't be used as an indicator of how
+                    the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
+            - ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of
+              `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention
+              <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),
+              or `"flash_attention_2"` (using `Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>`_).
+              By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"`
+              implementation.
+            - ``provider``: If backend is "onnx", this is the provider to use for inference, for example "CPUExecutionProvider",
+              "CUDAExecutionProvider", etc. See https://onnxruntime.ai/docs/execution-providers/ for all ONNX execution providers.
+            - ``file_name``: If backend is "onnx" or "openvino", this is the file name to load, useful for loading optimized
+              or quantized ONNX or OpenVINO models.
+            - ``export``: If backend is "onnx" or "openvino", then this is a boolean flag specifying whether this model should
+              be exported to the backend. If not specified, the model will be exported only if the model repository or directory
+              does not already contain an exported model.
+
+            See the `PreTrainedModel.from_pretrained
+            <https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_
+            documentation for more details.
+        tokenizer_kwargs (Dict[str, Any], optional): Additional tokenizer configuration parameters to be passed to the Hugging Face Transformers tokenizer.
+            See the `AutoTokenizer.from_pretrained
+            <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_
+            documentation for more details.
+        config_kwargs (Dict[str, Any], optional): Additional model configuration parameters to be passed to the Hugging Face Transformers config.
+            See the `AutoConfig.from_pretrained
+            <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_
+            documentation for more details.
+        model_card_data (:class:`~sentence_transformers.model_card.SentenceTransformerModelCardData`, optional): A model
+            card data object that contains information about the model. This is used to generate a model card when saving
+            the model. If not set, a default model card data object is created.
+        backend (str): The backend to use for inference. Can be one of "torch" (default), "onnx", or "openvino".
+            See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for benchmarking information
+            on the different backends.
+
+    Example:
+        ::
+
+            from sentence_transformers import SentenceTransformer
+
+            # Load a pre-trained SentenceTransformer model
+            model = SentenceTransformer('all-mpnet-base-v2')
+
+            # Encode some texts
+            sentences = [
+                "The weather is lovely today.",
+                "It's so sunny outside!",
+                "He drove to the stadium.",
+            ]
+            embeddings = model.encode(sentences)
+            print(embeddings.shape)
+            # (3, 768)
+
+            # Get the similarity scores between all sentences
+            similarities = model.similarity(embeddings, embeddings)
+            print(similarities)
+            # tensor([[1.0000, 0.6817, 0.0492],
+            #         [0.6817, 1.0000, 0.0421],
+            #         [0.0492, 0.0421, 1.0000]])
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str | None = None,
+        modules: Iterable[nn.Module] | None = None,
+        device: str | None = None,
+        prompts: dict[str, str] | None = None,
+        default_prompt_name: str | None = None,
+        similarity_fn_name: str | SimilarityFunction | None = None,
+        cache_folder: str | None = None,
+        trust_remote_code: bool = False,
+        revision: str | None = None,
+        local_files_only: bool = False,
+        token: bool | str | None = None,
+        use_auth_token: bool | str | None = None,
+        truncate_dim: int | None = None,
+        model_kwargs: dict[str, Any] | None = None,
+        tokenizer_kwargs: dict[str, Any] | None = None,
+        config_kwargs: dict[str, Any] | None = None,
+        model_card_data: SentenceTransformerModelCardData | None = None,
+        backend: Literal["torch", "onnx", "openvino"] = "torch",
+    ) -> None:
+        # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
+        self.prompts = prompts or {}
+        self.default_prompt_name = default_prompt_name
+        self.similarity_fn_name = similarity_fn_name
+        self.trust_remote_code = trust_remote_code
+        self.truncate_dim = truncate_dim
+        self.model_card_data = model_card_data or SentenceTransformerModelCardData()
+        self.module_kwargs = None
+        self._model_card_vars = {}
+        self._model_card_text = None
+        self._model_config = {}
+        self.backend = backend
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v4 of SentenceTransformers.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if cache_folder is None:
+            cache_folder = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+
+        if device is None:
+            device = get_device_name()
+            logger.info(f"Use pytorch device_name: {device}")
+
+        if device == "hpu" and importlib.util.find_spec("optimum") is not None:
+            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+            adapt_transformers_to_gaudi()
+
+        if model_name_or_path is not None and model_name_or_path != "":
+            logger.info(f"Load pretrained SentenceTransformer: {model_name_or_path}")
+
+            # Old models that don't belong to any organization
+            basic_transformer_models = [
+                "albert-base-v1",
+                "albert-base-v2",
+                "albert-large-v1",
+                "albert-large-v2",
+                "albert-xlarge-v1",
+                "albert-xlarge-v2",
+                "albert-xxlarge-v1",
+                "albert-xxlarge-v2",
+                "bert-base-cased-finetuned-mrpc",
+                "bert-base-cased",
+                "bert-base-chinese",
+                "bert-base-german-cased",
+                "bert-base-german-dbmdz-cased",
+                "bert-base-german-dbmdz-uncased",
+                "bert-base-multilingual-cased",
+                "bert-base-multilingual-uncased",
+                "bert-base-uncased",
+                "bert-large-cased-whole-word-masking-finetuned-squad",
+                "bert-large-cased-whole-word-masking",
+                "bert-large-cased",
+                "bert-large-uncased-whole-word-masking-finetuned-squad",
+                "bert-large-uncased-whole-word-masking",
+                "bert-large-uncased",
+                "camembert-base",
+                "ctrl",
+                "distilbert-base-cased-distilled-squad",
+                "distilbert-base-cased",
+                "distilbert-base-german-cased",
+                "distilbert-base-multilingual-cased",
+                "distilbert-base-uncased-distilled-squad",
+                "distilbert-base-uncased-finetuned-sst-2-english",
+                "distilbert-base-uncased",
+                "distilgpt2",
+                "distilroberta-base",
+                "gpt2-large",
+                "gpt2-medium",
+                "gpt2-xl",
+                "gpt2",
+                "openai-gpt",
+                "roberta-base-openai-detector",
+                "roberta-base",
+                "roberta-large-mnli",
+                "roberta-large-openai-detector",
+                "roberta-large",
+                "t5-11b",
+                "t5-3b",
+                "t5-base",
+                "t5-large",
+                "t5-small",
+                "transfo-xl-wt103",
+                "xlm-clm-ende-1024",
+                "xlm-clm-enfr-1024",
+                "xlm-mlm-100-1280",
+                "xlm-mlm-17-1280",
+                "xlm-mlm-en-2048",
+                "xlm-mlm-ende-1024",
+                "xlm-mlm-enfr-1024",
+                "xlm-mlm-enro-1024",
+                "xlm-mlm-tlm-xnli15-1024",
+                "xlm-mlm-xnli15-1024",
+                "xlm-roberta-base",
+                "xlm-roberta-large-finetuned-conll02-dutch",
+                "xlm-roberta-large-finetuned-conll02-spanish",
+                "xlm-roberta-large-finetuned-conll03-english",
+                "xlm-roberta-large-finetuned-conll03-german",
+                "xlm-roberta-large",
+                "xlnet-base-cased",
+                "xlnet-large-cased",
+            ]
+
+            if not os.path.exists(model_name_or_path):
+                # Not a path, load from hub
+                if "\\" in model_name_or_path or model_name_or_path.count("/") > 1:
+                    raise ValueError(f"Path {model_name_or_path} not found")
+
+                if "/" not in model_name_or_path and model_name_or_path.lower() not in basic_transformer_models:
+                    # A model from sentence-transformers
+                    model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path
+
+            if is_sentence_transformer_model(
+                model_name_or_path,
+                token,
+                cache_folder=cache_folder,
+                revision=revision,
+                local_files_only=local_files_only,
+            ):
+                modules, self.module_kwargs = self._load_sbert_model(
+                    model_name_or_path,
+                    token=token,
+                    cache_folder=cache_folder,
+                    revision=revision,
+                    trust_remote_code=trust_remote_code,
+                    local_files_only=local_files_only,
+                    model_kwargs=model_kwargs,
+                    tokenizer_kwargs=tokenizer_kwargs,
+                    config_kwargs=config_kwargs,
+                )
+            else:
+                modules = self._load_auto_model(
+                    model_name_or_path,
+                    token=token,
+                    cache_folder=cache_folder,
+                    revision=revision,
+                    trust_remote_code=trust_remote_code,
+                    local_files_only=local_files_only,
+                    model_kwargs=model_kwargs,
+                    tokenizer_kwargs=tokenizer_kwargs,
+                    config_kwargs=config_kwargs,
+                )
+
+        if modules is not None and not isinstance(modules, OrderedDict):
+            modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
+
+        super().__init__(modules)
+
+        # Ensure all tensors in the model are of the same dtype as the first tensor
+        # This is necessary if the first module has been given a lower precision via
+        # model_kwargs["torch_dtype"]. The rest of the model should be loaded in the same dtype
+        # See #2887 for more details
+        try:
+            dtype = next(self.parameters()).dtype
+            self.to(dtype)
+        except StopIteration:
+            pass
+
+        self.to(device)
+        self.is_hpu_graph_enabled = False
+
+        if self.default_prompt_name is not None and self.default_prompt_name not in self.prompts:
+            raise ValueError(
+                f"Default prompt name '{self.default_prompt_name}' not found in the configured prompts "
+                f"dictionary with keys {list(self.prompts.keys())!r}."
+            )
+
+        if self.prompts:
+            logger.info(f"{len(self.prompts)} prompts are loaded, with the keys: {list(self.prompts.keys())}")
+        if self.default_prompt_name:
+            logger.warning(
+                f"Default prompt name is set to '{self.default_prompt_name}'. "
+                "This prompt will be applied to all `encode()` calls, except if `encode()` "
+                "is called with `prompt` or `prompt_name` parameters."
+            )
+
+        # Ideally, INSTRUCTOR models should set `include_prompt=False` in their pooling configuration, but
+        # that would be a breaking change for users currently using the InstructorEmbedding project.
+        # So, instead we hardcode setting it for the main INSTRUCTOR models, and otherwise give a warning if we
+        # suspect the user is using an INSTRUCTOR model.
+        if model_name_or_path in ("hkunlp/instructor-base", "hkunlp/instructor-large", "hkunlp/instructor-xl"):
+            self.set_pooling_include_prompt(include_prompt=False)
+        elif (
+            model_name_or_path
+            and "/" in model_name_or_path
+            and "instructor" in model_name_or_path.split("/")[1].lower()
+        ):
+            if any([module.include_prompt for module in self if isinstance(module, Pooling)]):
+                logger.warning(
+                    "Instructor models require `include_prompt=False` in the pooling configuration. "
+                    "Either update the model configuration or call `model.set_pooling_include_prompt(False)` after loading the model."
+                )
+
+        # Pass the model to the model card data for later use in generating a model card upon saving this model
+        self.model_card_data.register_model(self)
+
+    def get_backend(self) -> Literal["torch", "onnx", "openvino"]:
+        """Return the backend used for inference, which can be one of "torch", "onnx", or "openvino".
+
+        Returns:
+            str: The backend used for inference.
+        """
+        return self.backend
+
+    @overload
+    def encode(
+        self,
+        sentences: str,
+        prompt_name: str | None = ...,
+        prompt: str | None = ...,
+        batch_size: int = ...,
+        show_progress_bar: bool | None = ...,
+        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
+        convert_to_numpy: Literal[False] = ...,
+        convert_to_tensor: Literal[False] = ...,
+        device: str = ...,
+        normalize_embeddings: bool = ...,
+        **kwargs,
+    ) -> Tensor: ...
+
+    @overload
+    def encode(
+        self,
+        sentences: str | list[str],
+        prompt_name: str | None = ...,
+        prompt: str | None = ...,
+        batch_size: int = ...,
+        show_progress_bar: bool | None = ...,
+        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
+        convert_to_numpy: Literal[True] = ...,
+        convert_to_tensor: Literal[False] = ...,
+        device: str = ...,
+        normalize_embeddings: bool = ...,
+        **kwargs,
+    ) -> np.ndarray: ...
+
+    @overload
+    def encode(
+        self,
+        sentences: str | list[str],
+        prompt_name: str | None = ...,
+        prompt: str | None = ...,
+        batch_size: int = ...,
+        show_progress_bar: bool | None = ...,
+        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
+        convert_to_numpy: bool = ...,
+        convert_to_tensor: Literal[True] = ...,
+        device: str = ...,
+        normalize_embeddings: bool = ...,
+        **kwargs,
+    ) -> Tensor: ...
+
+    @overload
+    def encode(
+        self,
+        sentences: list[str] | np.ndarray,
+        prompt_name: str | None = ...,
+        prompt: str | None = ...,
+        batch_size: int = ...,
+        show_progress_bar: bool | None = ...,
+        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
+        convert_to_numpy: Literal[False] = ...,
+        convert_to_tensor: Literal[False] = ...,
+        device: str = ...,
+        normalize_embeddings: bool = ...,
+        **kwargs,
+    ) -> list[Tensor]: ...
+
+    def encode(
+        self,
+        sentences: str | list[str],
+        prompt_name: str | None = None,
+        prompt: str | None = None,
+        batch_size: int = 32,
+        show_progress_bar: bool | None = None,
+        output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding",
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: str = None,
+        normalize_embeddings: bool = False,
+        **kwargs,
+    ) -> list[Tensor] | np.ndarray | Tensor:
+        """
+        Computes sentence embeddings.
+
+        Args:
+            sentences (Union[str, List[str]]): The sentences to embed.
+            prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
+                which is either set in the constructor or loaded from the model configuration. For example if
+                ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
+                is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
+                is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
+            prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
+                sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+                because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
+            batch_size (int, optional): The batch size used for the computation. Defaults to 32.
+            show_progress_bar (bool, optional): Whether to output a progress bar when encode sentences. Defaults to None.
+            output_value (Optional[Literal["sentence_embedding", "token_embeddings"]], optional): The type of embeddings to return:
+                "sentence_embedding" to get sentence embeddings, "token_embeddings" to get wordpiece token embeddings, and `None`,
+                to get all output values. Defaults to "sentence_embedding".
+            precision (Literal["float32", "int8", "uint8", "binary", "ubinary"], optional): The precision to use for the embeddings.
+                Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions are quantized embeddings.
+                Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy. They are useful for
+                reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to "float32".
+            convert_to_numpy (bool, optional): Whether the output should be a list of numpy vectors. If False, it is a list of PyTorch tensors.
+                Defaults to True.
+            convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
+                Defaults to False.
+            device (str, optional): Which :class:`torch.device` to use for the computation. Defaults to None.
+            normalize_embeddings (bool, optional): Whether to normalize returned vectors to have length 1. In that case,
+                the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
+
+        Returns:
+            Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
+            If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If ``convert_to_tensor``,
+            a torch Tensor is returned instead. If ``self.truncate_dim <= output_dimension`` then output_dimension is ``self.truncate_dim``.
+
+        Example:
+            ::
+
+                from sentence_transformers import SentenceTransformer
+
+                # Load a pre-trained SentenceTransformer model
+                model = SentenceTransformer('all-mpnet-base-v2')
+
+                # Encode some texts
+                sentences = [
+                    "The weather is lovely today.",
+                    "It's so sunny outside!",
+                    "He drove to the stadium.",
+                ]
+                embeddings = model.encode(sentences)
+                print(embeddings.shape)
+                # (3, 768)
+        """
+        if self.device.type == "hpu" and not self.is_hpu_graph_enabled:
+            import habana_frameworks.torch as ht
+
+            ht.hpu.wrap_in_hpu_graph(self, disable_tensor_cache=True)
+            self.is_hpu_graph_enabled = True
+
+        self.eval()
+        if show_progress_bar is None:
+            show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
+
+        if convert_to_tensor:
+            convert_to_numpy = False
+
+        if output_value != "sentence_embedding":
+            convert_to_tensor = False
+            convert_to_numpy = False
+
+        input_was_string = False
+        if isinstance(sentences, str) or not hasattr(
+            sentences, "__len__"
+        ):  # Cast an individual sentence to a list with length 1
+            sentences = [sentences]
+            input_was_string = True
+
+        if prompt is None:
+            if prompt_name is not None:
+                try:
+                    prompt = self.prompts[prompt_name]
+                except KeyError:
+                    raise ValueError(
+                        f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.prompts.keys())!r}."
+                    )
+            elif self.default_prompt_name is not None:
+                prompt = self.prompts.get(self.default_prompt_name, None)
+        else:
+            if prompt_name is not None:
+                logger.warning(
+                    "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
+                    "Ignoring the `prompt_name` in favor of `prompt`."
+                )
+
+        extra_features = {}
+        if prompt is not None:
+            sentences = [prompt + sentence for sentence in sentences]
+
+            # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
+            # Tracking the prompt length allow us to remove the prompt during pooling
+            tokenized_prompt = self.tokenize([prompt])
+            if "input_ids" in tokenized_prompt:
+                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
+
+        if device is None:
+            device = self.device
+
+        self.to(device)
+
+        all_embeddings = []
+        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
+        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+        for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
+            sentences_batch = sentences_sorted[start_index : start_index + batch_size]
+            features = self.tokenize(sentences_batch)
+            if self.device.type == "hpu":
+                if "input_ids" in features:
+                    curr_tokenize_len = features["input_ids"].shape
+                    additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
+                    features["input_ids"] = torch.cat(
+                        (
+                            features["input_ids"],
+                            torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                        ),
+                        -1,
+                    )
+                    features["attention_mask"] = torch.cat(
+                        (
+                            features["attention_mask"],
+                            torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                        ),
+                        -1,
+                    )
+                    if "token_type_ids" in features:
+                        features["token_type_ids"] = torch.cat(
+                            (
+                                features["token_type_ids"],
+                                torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                            ),
+                            -1,
+                        )
+
+            features = batch_to_device(features, device)
+            features.update(extra_features)
+
+            with torch.no_grad():
+                out_features = self.forward(features, **kwargs)
+                if self.device.type == "hpu":
+                    out_features = copy.deepcopy(out_features)
+
+                out_features["sentence_embedding"] = truncate_embeddings(
+                    out_features["sentence_embedding"], self.truncate_dim
+                )
+
+                if output_value == "token_embeddings":
+                    embeddings = []
+                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
+                        last_mask_id = len(attention) - 1
+                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                            last_mask_id -= 1
+
+                        embeddings.append(token_emb[0 : last_mask_id + 1])
+                elif output_value is None:  # Return all outputs
+                    embeddings = []
+                    for sent_idx in range(len(out_features["sentence_embedding"])):
+                        row = {name: out_features[name][sent_idx] for name in out_features}
+                        embeddings.append(row)
+                else:  # Sentence embeddings
+                    embeddings = out_features[output_value]
+                    embeddings = embeddings.detach()
+                    if normalize_embeddings:
+                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+                    # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
+                    if convert_to_numpy:
+                        embeddings = embeddings.cpu()
+
+                all_embeddings.extend(embeddings)
+
+        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
+
+        if precision and precision != "float32":
+            all_embeddings = quantize_embeddings(all_embeddings, precision=precision)
+
+        if convert_to_tensor:
+            if len(all_embeddings):
+                if isinstance(all_embeddings, np.ndarray):
+                    all_embeddings = torch.from_numpy(all_embeddings)
+                else:
+                    all_embeddings = torch.stack(all_embeddings)
+            else:
+                all_embeddings = torch.Tensor()
+        elif convert_to_numpy:
+            if not isinstance(all_embeddings, np.ndarray):
+                if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
+                    all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
+                else:
+                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+        elif isinstance(all_embeddings, np.ndarray):
+            all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]
+
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+
+        return all_embeddings
+
+    def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
+        if self.module_kwargs is None:
+            return super().forward(input)
+
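+        # Pass each keyword argument only to the modules that declared it in `self.module_kwargs`,
+        # so module-specific options reach the right sub-module and nothing else.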
+        for module_name, module in self.named_children():
+            module_kwarg_keys = self.module_kwargs.get(module_name, [])
+            module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
+            input = module(input, **module_kwargs)
+        return input
+
+    @property
+    def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]:
+        """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
+
+        Returns:
+            Optional[str]: The name of the similarity function. Can be None if not set, in which case it will
+                default to "cosine" when first called.
+
+        Example:
+            >>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
+            >>> model.similarity_fn_name
+            'dot'
+        """
+        if self._similarity_fn_name is None:
+            self.similarity_fn_name = SimilarityFunction.COSINE
+        return self._similarity_fn_name
+
+    @similarity_fn_name.setter
+    def similarity_fn_name(
+        self, value: Literal["cosine", "dot", "euclidean", "manhattan"] | SimilarityFunction
+    ) -> None:
+        if isinstance(value, SimilarityFunction):
+            value = value.value
+        self._similarity_fn_name = value
+
+        if value is not None:
+            self._similarity = SimilarityFunction.to_similarity_fn(value)
+            self._similarity_pairwise = SimilarityFunction.to_similarity_pairwise_fn(value)
+
+    @overload
+    def similarity(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...
+
+    @overload
+    def similarity(self, embeddings1: ndarray, embeddings2: ndarray) -> Tensor: ...
+
+    @property
+    def similarity(self) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
+        """
+        Compute the similarity between two collections of embeddings. The output will be a matrix with the similarity
+        scores between all embeddings from the first parameter and all embeddings from the second parameter. This
+        differs from `similarity_pairwise` which computes the similarity between each pair of embeddings.
+
+        Args:
+            embeddings1 (Union[Tensor, ndarray]): [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+            embeddings2 (Union[Tensor, ndarray]): [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
+        Returns:
+            Tensor: A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.
+
+        Example:
+            ::
+
+                >>> model = SentenceTransformer("all-mpnet-base-v2")
+                >>> sentences = [
+                ...     "The weather is so nice!",
+                ...     "It's so sunny outside.",
+                ...     "He's driving to the movie theater.",
+                ...     "She's going to the cinema.",
+                ... ]
+                >>> embeddings = model.encode(sentences, normalize_embeddings=True)
+                >>> model.similarity(embeddings, embeddings)
+                tensor([[1.0000, 0.7235, 0.0290, 0.1309],
+                        [0.7235, 1.0000, 0.0613, 0.1129],
+                        [0.0290, 0.0613, 1.0000, 0.5027],
+                        [0.1309, 0.1129, 0.5027, 1.0000]])
+                >>> model.similarity_fn_name
+                "cosine"
+                >>> model.similarity_fn_name = "euclidean"
+                >>> model.similarity(embeddings, embeddings)
+                tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
+                        [-0.7437, -0.0000, -1.3702, -1.3320],
+                        [-1.3935, -1.3702, -0.0000, -0.9973],
+                        [-1.3184, -1.3320, -0.9973, -0.0000]])
+        """
+        if self.similarity_fn_name is None:
+            self.similarity_fn_name = SimilarityFunction.COSINE
+        return self._similarity
+
+    @overload
+    def similarity_pairwise(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...
+
+    @overload
+    def similarity_pairwise(self, embeddings1: ndarray, embeddings2: ndarray) -> Tensor: ...
+
+    @property
+    def similarity_pairwise(self) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
+        """
+        Compute the similarity between two collections of embeddings. The output will be a vector with the similarity
+        scores between each pair of embeddings.
+
+        Args:
+            embeddings1 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+            embeddings2 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
+        Returns:
+            Tensor: A [num_embeddings]-shaped torch tensor with pairwise similarity scores.
+
+        Example:
+            ::
+
+                >>> model = SentenceTransformer("all-mpnet-base-v2")
+                >>> sentences = [
+                ...     "The weather is so nice!",
+                ...     "It's so sunny outside.",
+                ...     "He's driving to the movie theater.",
+                ...     "She's going to the cinema.",
+                ... ]
+                >>> embeddings = model.encode(sentences, normalize_embeddings=True)
+                >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+                tensor([0.7235, 0.5027])
+                >>> model.similarity_fn_name
+                "cosine"
+                >>> model.similarity_fn_name = "euclidean"
+                >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+                tensor([-0.7437, -0.9973])
+        """
+        if self.similarity_fn_name is None:
+            self.similarity_fn_name = SimilarityFunction.COSINE
+        return self._similarity_pairwise
+
+    def start_multi_process_pool(
+        self, target_devices: list[str] | None = None
+    ) -> dict[Literal["input", "output", "processes"], Any]:
+        """
+        Starts a multi-process pool to process the encoding with several independent processes
+        via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.
+
+        This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
+        to start only one process per GPU. This method works together with encode_multi_process
+        and stop_multi_process_pool.
+
+        Args:
+            target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
+                ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
+                is available, then all available CUDA/NPU devices will be used. If target_devices is None and
+                CUDA/NPU is not available, then 4 CPU devices will be used.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
+        """
+        if target_devices is None:
+            if torch.cuda.is_available():
+                target_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+            elif is_torch_npu_available():
+                target_devices = [f"npu:{i}" for i in range(torch.npu.device_count())]
+            else:
+                logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
+                target_devices = ["cpu"] * 4
+
+        logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))
+
+        self.to("cpu")
+        self.share_memory()
+        ctx = mp.get_context("spawn")
+        input_queue = ctx.Queue()
+        output_queue = ctx.Queue()
+        processes = []
+
+        for device_id in target_devices:
+            p = ctx.Process(
+                target=SentenceTransformer._encode_multi_process_worker,
+                args=(device_id, self, input_queue, output_queue),
+                daemon=True,
+            )
+            p.start()
+            processes.append(p)
+
+        return {"input": input_queue, "output": output_queue, "processes": processes}
+
+    @staticmethod
+    def stop_multi_process_pool(pool: dict[Literal["input", "output", "processes"], Any]) -> None:
+        """
+        Stops all processes started with start_multi_process_pool.
+
+        Args:
+            pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.
+
+        Returns:
+            None
+        """
+        for p in pool["processes"]:
+            p.terminate()
+
+        for p in pool["processes"]:
+            p.join()
+            p.close()
+
+        pool["input"].close()
+        pool["output"].close()
+
+    def encode_multi_process(
+        self,
+        sentences: list[str],
+        pool: dict[Literal["input", "output", "processes"], Any],
+        prompt_name: str | None = None,
+        prompt: str | None = None,
+        batch_size: int = 32,
+        chunk_size: int | None = None,
+        show_progress_bar: bool | None = None,
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        normalize_embeddings: bool = False,
+    ) -> np.ndarray:
+        """
+        Encodes a list of sentences using multiple processes and GPUs via
+        :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.
+        The sentences are chunked into smaller packages and sent to individual processes, which encode them on different
+        GPUs or CPUs. This method is only suitable for encoding large sets of sentences.
+
+        Args:
+            sentences (List[str]): List of sentences to encode.
+            pool (Dict[Literal["input", "output", "processes"], Any]): A pool of workers started with
+                :meth:`SentenceTransformer.start_multi_process_pool <sentence_transformers.SentenceTransformer.start_multi_process_pool>`.
+            prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
+                which is either set in the constructor or loaded from the model configuration. For example if
+                ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
+                is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
+                is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
+            prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
+                sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+                because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
+            batch_size (int): Encode sentences with this batch size. Defaults to 32.
+            chunk_size (int): Sentences are chunked and sent to the individual processes. If None, a sensible
+                chunk size is determined automatically. Defaults to None.
+            show_progress_bar (bool, optional): Whether to output a progress bar when encoding sentences. Defaults to None.
+            precision (Literal["float32", "int8", "uint8", "binary", "ubinary"]): The precision to use for the
+                embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions
+                are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may
+                have lower accuracy. They are useful for reducing the size of the embeddings of a corpus for
+                semantic search, among other tasks. Defaults to "float32".
+            normalize_embeddings (bool): Whether to normalize returned vectors to have length 1. In that case,
+                the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
+
+        Returns:
+            np.ndarray: A 2D numpy array with shape [num_inputs, output_dimension].
+
+        Example:
+            ::
+
+                from sentence_transformers import SentenceTransformer
+
+                def main():
+                    model = SentenceTransformer("all-mpnet-base-v2")
+                    sentences = ["The weather is so nice!", "It's so sunny outside.", "He's driving to the movie theater.", "She's going to the cinema."] * 1000
+
+                    pool = model.start_multi_process_pool()
+                    embeddings = model.encode_multi_process(sentences, pool)
+                    model.stop_multi_process_pool(pool)
+
+                    print(embeddings.shape)
+                    # => (4000, 768)
+
+                if __name__ == "__main__":
+                    main()
+        """
+
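+        # Choose a chunk size that yields roughly ten chunks per worker process, capped at 5000 sentences,
+        # so the work is spread evenly across the pool without overly large queue items.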
+        if chunk_size is None:
+            chunk_size = min(math.ceil(len(sentences) / len(pool["processes"]) / 10), 5000)
+
+        if show_progress_bar is None:
+            show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
+
+        logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")
+
+        input_queue = pool["input"]
+        last_chunk_id = 0
+        chunk = []
+
+        for sentence in sentences:
+            chunk.append(sentence)
+            if len(chunk) >= chunk_size:
+                input_queue.put(
+                    [last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings]
+                )
+                last_chunk_id += 1
+                chunk = []
+
+        if len(chunk) > 0:
+            input_queue.put([last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings])
+            last_chunk_id += 1
+
+        output_queue = pool["output"]
+        results_list = sorted(
+            [output_queue.get() for _ in trange(last_chunk_id, desc="Chunks", disable=not show_progress_bar)],
+            key=lambda x: x[0],
+        )
+        embeddings = np.concatenate([result[1] for result in results_list])
+        return embeddings
+
+    @staticmethod
+    def _encode_multi_process_worker(
+        target_device: str, model: SentenceTransformer, input_queue: Queue, results_queue: Queue
+    ) -> None:
+        """
+        Internal worker process that encodes sentences in the multi-process setup
+        """
+        while True:
+            try:
+                chunk_id, batch_size, sentences, prompt_name, prompt, precision, normalize_embeddings = (
+                    input_queue.get()
+                )
+                embeddings = model.encode(
+                    sentences,
+                    prompt_name=prompt_name,
+                    prompt=prompt,
+                    device=target_device,
+                    show_progress_bar=False,
+                    precision=precision,
+                    convert_to_numpy=True,
+                    batch_size=batch_size,
+                    normalize_embeddings=normalize_embeddings,
+                )
+
+                results_queue.put([chunk_id, embeddings])
+            except queue.Empty:
+                break
+
+    def set_pooling_include_prompt(self, include_prompt: bool) -> None:
+        """
+        Sets the `include_prompt` attribute in the pooling layer in the model, if there is one.
+
+        This is useful for INSTRUCTOR models, as the prompt should be excluded from the pooling strategy
+        for these models.
+
+        Args:
+            include_prompt (bool): Whether to include the prompt in the pooling layer.
+
+        Returns:
+            None
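+
+        Example:
+            A minimal sketch; ``"hkunlp/instructor-base"`` merely stands in for an INSTRUCTOR-style checkpoint::
+
+                from sentence_transformers import SentenceTransformer
+
+                model = SentenceTransformer("hkunlp/instructor-base")
+                # Exclude the instruction prompt from pooling, as recommended for INSTRUCTOR models
+                model.set_pooling_include_prompt(False)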
+        """
+        for module in self:
+            if isinstance(module, Pooling):
+                module.include_prompt = include_prompt
+                break
+
+    def get_max_seq_length(self) -> int | None:
+        """
+        Returns the maximal sequence length that the model accepts. Longer inputs will be truncated.
+
+        Returns:
+            Optional[int]: The maximal sequence length that the model accepts, or None if it is not defined.
+        """
+        if hasattr(self._first_module(), "max_seq_length"):
+            return self._first_module().max_seq_length
+
+        return None
+
+    def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) -> dict[str, Tensor]:
+        """
+        Tokenizes the texts.
+
+        Args:
+            texts (Union[List[str], List[Dict], List[Tuple[str, str]]]): A list of texts to be tokenized.
+
+        Returns:
+            Dict[str, Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids",
+                "attention_mask", and "token_type_ids".
+        """
+        return self._first_module().tokenize(texts)
+
+    def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], Tensor]:
+        return self._first_module().get_sentence_features(*features)
+
+    def get_sentence_embedding_dimension(self) -> int | None:
+        """
+        Returns the number of dimensions in the output of :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.
+
+        Returns:
+            Optional[int]: The number of dimensions in the output of `encode`. If it's not known, it's `None`.
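+
+        Example:
+            A minimal sketch; ``"all-mpnet-base-v2"`` is used purely as an illustrative model that outputs
+            768-dimensional embeddings::
+
+                from sentence_transformers import SentenceTransformer
+
+                model = SentenceTransformer("all-mpnet-base-v2")
+                model.get_sentence_embedding_dimension()
+                # => 768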
+        """
+        output_dim = None
+        for mod in reversed(self._modules.values()):
+            sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None)
+            if callable(sent_embedding_dim_method):
+                output_dim = sent_embedding_dim_method()
+                break
+        if self.truncate_dim is not None:
+            # The user requested truncation. If they set it to a dim greater than output_dim,
+            # no truncation will actually happen. So return output_dim instead of self.truncate_dim
+            return min(output_dim or np.inf, self.truncate_dim)
+        return output_dim
+
+    @contextmanager
+    def truncate_sentence_embeddings(self, truncate_dim: int | None) -> Iterator[None]:
+        """
+        In this context, :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>` outputs
+        sentence embeddings truncated at dimension ``truncate_dim``.
+
+        This may be useful when you are using the same model for different applications where different dimensions
+        are needed.
+
+        Args:
+            truncate_dim (int, optional): The dimension to truncate sentence embeddings to. ``None`` does no truncation.
+
+        Example:
+            ::
+
+                from sentence_transformers import SentenceTransformer
+
+                model = SentenceTransformer("all-mpnet-base-v2")
+
+                with model.truncate_sentence_embeddings(truncate_dim=16):
+                    embeddings_truncated = model.encode(["hello there", "hiya"])
+                assert embeddings_truncated.shape[-1] == 16
+        """
+        original_output_dim = self.truncate_dim
+        try:
+            self.truncate_dim = truncate_dim
+            yield
+        finally:
+            self.truncate_dim = original_output_dim
+
+    def _first_module(self) -> torch.nn.Module:
+        """Returns the first module of this sequential embedder"""
+        return self._modules[next(iter(self._modules))]
+
+    def _last_module(self) -> torch.nn.Module:
+        """Returns the last module of this sequential embedder"""
+        return self._modules[next(reversed(self._modules))]
+
+    def save(
+        self,
+        path: str,
+        model_name: str | None = None,
+        create_model_card: bool = True,
+        train_datasets: list[str] | None = None,
+        safe_serialization: bool = True,
+    ) -> None:
+        """
+        Saves a model and its configuration files to a directory, so that it can be loaded
+        with ``SentenceTransformer(path)`` again.
+
+        Args:
+            path (str): Path on disc where the model will be saved.
+            model_name (str, optional): Optional model name.
+            create_model_card (bool, optional): If True, create a README.md with basic information about this model.
+            train_datasets (List[str], optional): Optional list with the names of the datasets used to train the model.
+            safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model
+                the traditional (but unsafe) PyTorch way.
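+
+        Example:
+            A minimal save-and-reload sketch; the output directory is just an illustrative local path::
+
+                from sentence_transformers import SentenceTransformer
+
+                model = SentenceTransformer("all-mpnet-base-v2")
+                model.save("output/my-local-model")
+                reloaded_model = SentenceTransformer("output/my-local-model")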
+        """
+        if path is None:
+            return
+
+        os.makedirs(path, exist_ok=True)
+
+        logger.info(f"Save model to {path}")
+        modules_config = []
+
+        # Save some model info
+        self._model_config["__version__"] = {
+            "sentence_transformers": __version__,
+            "transformers": transformers.__version__,
+            "pytorch": torch.__version__,
+        }
+
+        with open(os.path.join(path, "config_sentence_transformers.json"), "w") as fOut:
+            config = self._model_config.copy()
+            config["prompts"] = self.prompts
+            config["default_prompt_name"] = self.default_prompt_name
+            config["similarity_fn_name"] = self.similarity_fn_name
+            json.dump(config, fOut, indent=2)
+
+        # Save modules
+        for idx, name in enumerate(self._modules):
+            module = self._modules[name]
+            if idx == 0 and hasattr(module, "save_in_root"):  # Save first module in the main folder
+                model_path = path + "/"
+            else:
+                model_path = os.path.join(path, str(idx) + "_" + type(module).__name__)
+
+            os.makedirs(model_path, exist_ok=True)
+            # Try to save with safetensors, but fall back to the traditional PyTorch way if the module doesn't support it
+            try:
+                module.save(model_path, safe_serialization=safe_serialization)
+            except TypeError:
+                module.save(model_path)
+
+            # "module" only works for Sentence Transformers as the modules have the same names as the classes
+            class_ref = type(module).__module__
+            # For remote modules, we want to remove "transformers_modules.{repo_name}":
+            if class_ref.startswith("transformers_modules."):
+                class_file = sys.modules[class_ref].__file__
+
+                # Save the custom module file
+                dest_file = Path(model_path) / (Path(class_file).name)
+                shutil.copy(class_file, dest_file)
+
+                # Save all files imported in the custom module file
+                for needed_file in get_relative_import_files(class_file):
+                    dest_file = Path(model_path) / (Path(needed_file).name)
+                    shutil.copy(needed_file, dest_file)
+
+                # For remote modules, we want to ignore the "transformers_modules.{repo_id}" part,
+                # i.e. we only want the filename
+                class_ref = f"{class_ref.split('.')[-1]}.{type(module).__name__}"
+            # For other cases, we want to add the class name:
+            elif not class_ref.startswith("sentence_transformers."):
+                class_ref = f"{class_ref}.{type(module).__name__}"
+            modules_config.append({"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref})
+
+        with open(os.path.join(path, "modules.json"), "w") as fOut:
+            json.dump(modules_config, fOut, indent=2)
+
+        # Create model card
+        if create_model_card:
+            self._create_model_card(path, model_name, train_datasets)
+
+    def save_pretrained(
+        self,
+        path: str,
+        model_name: str | None = None,
+        create_model_card: bool = True,
+        train_datasets: list[str] | None = None,
+        safe_serialization: bool = True,
+    ) -> None:
+        """
+        Saves a model and its configuration files to a directory, so that it can be loaded
+        with ``SentenceTransformer(path)`` again.
+
+        Args:
+            path (str): Path on disc where the model will be saved.
+            model_name (str, optional): Optional model name.
+            create_model_card (bool, optional): If True, create a README.md with basic information about this model.
+            train_datasets (List[str], optional): Optional list with the names of the datasets used to train the model.
+            safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model
+                the traditional (but unsafe) PyTorch way.
+        """
+        self.save(
+            path,
+            model_name=model_name,
+            create_model_card=create_model_card,
+            train_datasets=train_datasets,
+            safe_serialization=safe_serialization,
+        )
+
+    def _create_model_card(
+        self, path: str, model_name: str | None = None, train_datasets: list[str] | None = "deprecated"
+    ) -> None:
+        """
+        Creates an automatic model card and stores it in the specified path. If no training was done and the loaded
+        model was already a Sentence Transformer model, then its model card is reused.
+
+        Args:
+            path (str): The path where the model card will be stored.
+            model_name (Optional[str], optional): The name of the model. Defaults to None.
+            train_datasets (Optional[List[str]], optional): Deprecated argument. Defaults to "deprecated".
+
+        Returns:
+            None
+        """
+        if model_name:
+            model_path = Path(model_name)
+            if not model_path.exists() and not self.model_card_data.model_id:
+                self.model_card_data.model_id = model_name
+
+        # If we loaded a Sentence Transformer model from the Hub, and no training was done, then
+        # we don't generate a new model card, but reuse the old one instead.
+        if self._model_card_text and self.model_card_data.trainer is None:
+            model_card = self._model_card_text
+            if self.model_card_data.model_id:
+                # The original model card contains a placeholder model_id; replace it with the new model_id
+                model_card = model_card.replace(
+                    'model = SentenceTransformer("sentence_transformers_model_id"',
+                    f'model = SentenceTransformer("{self.model_card_data.model_id}"',
+                )
+        else:
+            try:
+                model_card = generate_model_card(self)
+            except Exception:
+                logger.error(
+                    f"Error while generating model card:\n{traceback.format_exc()}"
+                    "Consider opening an issue on https://github.com/UKPLab/sentence-transformers/issues with this traceback.\n"
+                    "Skipping model card creation."
+                )
+                return
+
+        with open(os.path.join(path, "README.md"), "w", encoding="utf8") as fOut:
+            fOut.write(model_card)
+
+    @save_to_hub_args_decorator
+    def save_to_hub(
+        self,
+        repo_id: str,
+        organization: str | None = None,
+        token: str | None = None,
+        private: bool | None = None,
+        safe_serialization: bool = True,
+        commit_message: str = "Add new SentenceTransformer model.",
+        local_model_path: str | None = None,
+        exist_ok: bool = False,
+        replace_model_card: bool = False,
+        train_datasets: list[str] | None = None,
+    ) -> str:
+        """
+        DEPRECATED, use `push_to_hub` instead.
+
+        Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.
+
+        Args:
+            repo_id (str): Repository name for your model in the Hub, including the user or organization.
+            token (str, optional): An authentication token (See https://huggingface.co/settings/token)
+            private (bool, optional): Set to true, for hosting a private model
+            safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way
+            commit_message (str, optional): Message to commit while pushing.
+            local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
+            exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
+            replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card
+            train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
+
+        Returns:
+            str: The url of the commit of your model in the repository on the Hugging Face Hub.
+        """
+        logger.warning(
+            "The `save_to_hub` method is deprecated and will be removed in a future version of SentenceTransformers."
+            " Please use `push_to_hub` instead for future model uploads."
+        )
+
+        if organization:
+            if "/" not in repo_id:
+                logger.warning(
+                    f'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="{organization}/{repo_id}"` instead.'
+                )
+                repo_id = f"{organization}/{repo_id}"
+            elif repo_id.split("/")[0] != organization:
+                raise ValueError(
+                    "Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."
+                )
+            else:
+                logger.warning(
+                    f'Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id="{repo_id}"` instead.'
+                )
+
+        return self.push_to_hub(
+            repo_id=repo_id,
+            token=token,
+            private=private,
+            safe_serialization=safe_serialization,
+            commit_message=commit_message,
+            local_model_path=local_model_path,
+            exist_ok=exist_ok,
+            replace_model_card=replace_model_card,
+            train_datasets=train_datasets,
+        )
+
+    def push_to_hub(
+        self,
+        repo_id: str,
+        token: str | None = None,
+        private: bool | None = None,
+        safe_serialization: bool = True,
+        commit_message: str | None = None,
+        local_model_path: str | None = None,
+        exist_ok: bool = False,
+        replace_model_card: bool = False,
+        train_datasets: list[str] | None = None,
+        revision: str | None = None,
+        create_pr: bool = False,
+    ) -> str:
+        """
+        Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.
+
+        Args:
+            repo_id (str): Repository name for your model in the Hub, including the user or organization.
+            token (str, optional): An authentication token (See https://huggingface.co/settings/token)
+            private (bool, optional): Set to true, for hosting a private model
+            safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way
+            commit_message (str, optional): Message to commit while pushing.
+            local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
+            exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
+            replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card
+            train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
+            revision (str, optional): Branch to push the uploaded files to
+            create_pr (bool, optional): If True, create a pull request instead of pushing directly to the main branch
+
+        Returns:
+            str: The url of the commit of your model in the repository on the Hugging Face Hub.
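+
+        Example:
+            A minimal sketch; ``"username/my-model"`` is a placeholder repository id and assumes you are
+            already authenticated (e.g. via ``huggingface-cli login`` or the ``token`` argument)::
+
+                from sentence_transformers import SentenceTransformer
+
+                model = SentenceTransformer("all-mpnet-base-v2")
+                model.push_to_hub("username/my-model", private=True, commit_message="Initial upload")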
+        """
+        api = HfApi(token=token)
+        repo_url = api.create_repo(
+            repo_id=repo_id,
+            private=private,
+            repo_type=None,
+            exist_ok=exist_ok or create_pr,
+        )
+        repo_id = repo_url.repo_id  # Update the repo_id in case the old repo_id didn't contain a user or organization
+        self.model_card_data.set_model_id(repo_id)
+        if revision is not None:
+            api.create_branch(repo_id=repo_id, branch=revision, exist_ok=True)
+
+        if commit_message is None:
+            backend = self.get_backend()
+            if backend == "torch":
+                commit_message = "Add new SentenceTransformer model"
+            else:
+                commit_message = f"Add new SentenceTransformer model with an {backend} backend"
+
+        commit_description = ""
+        if create_pr:
+            commit_description = f"""\
+Hello!
+
+*This pull request has been automatically generated from the [`push_to_hub`](https://sbert.net/docs/package_reference/sentence_transformer/SentenceTransformer.html#sentence_transformers.SentenceTransformer.push_to_hub) method from the Sentence Transformers library.*
+
+## Full Model Architecture:
+```
+{self}
+```
+
+## Tip:
+Consider testing this pull request before merging by loading the model from this PR with the `revision` argument:
+```python
+from sentence_transformers import SentenceTransformer
+
+# TODO: Fill in the PR number
+pr_number = 2
+model = SentenceTransformer(
+    "{repo_id}",
+    revision=f"refs/pr/{{pr_number}}",
+    backend="{self.get_backend()}",
+)
+
+# Verify that everything works as expected
+embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."])
+print(embeddings.shape)
+
+similarities = model.similarity(embeddings, embeddings)
+print(similarities)
+```
+"""
+
+        if local_model_path:
+            folder_url = api.upload_folder(
+                repo_id=repo_id,
+                folder_path=local_model_path,
+                commit_message=commit_message,
+                commit_description=commit_description,
+                revision=revision,
+                create_pr=create_pr,
+            )
+        else:
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, "README.md"))
+                self.save_pretrained(
+                    tmp_dir,
+                    model_name=repo_url.repo_id,
+                    create_model_card=create_model_card,
+                    train_datasets=train_datasets,
+                    safe_serialization=safe_serialization,
+                )
+                folder_url = api.upload_folder(
+                    repo_id=repo_id,
+                    folder_path=tmp_dir,
+                    commit_message=commit_message,
+                    commit_description=commit_description,
+                    revision=revision,
+                    create_pr=create_pr,
+                )
+
+        if create_pr:
+            return folder_url.pr_url
+        return folder_url.commit_url
+
+    def _text_length(self, text: list[int] | list[list[int]]) -> int:
+        """
+        Helper function to get the length of the input text. The text can be a dict of tokenized features,
+        a single tokenized text (a list of ints), or several texts (a list of lists of ints or strings),
+        in which case the summed length of the individual items is returned.
+        """
+
+        if isinstance(text, dict):  # {key: value} case
+            return len(next(iter(text.values())))
+        elif not hasattr(text, "__len__"):  # Object has no len() method
+            return 1
+        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
+            return len(text)
+        else:
+            return sum([len(t) for t in text])  # Sum of length of individual strings
+
+    def evaluate(self, evaluator: SentenceEvaluator, output_path: str | None = None) -> dict[str, float] | float:
+        """
+        Evaluate the model based on an evaluator
+
+        Args:
+            evaluator (SentenceEvaluator): The evaluator used to evaluate the model.
+            output_path (str, optional): The path where the evaluator can write the results. Defaults to None.
+
+        Returns:
+            The evaluation results.
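+
+        Example:
+            A minimal sketch with an embedding-similarity evaluator; the sentence pairs and scores are made up
+            purely for illustration::
+
+                from sentence_transformers import SentenceTransformer
+                from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
+
+                model = SentenceTransformer("all-mpnet-base-v2")
+                evaluator = EmbeddingSimilarityEvaluator(
+                    sentences1=["The weather is nice.", "He drove to work."],
+                    sentences2=["It is sunny outside.", "She took the bus."],
+                    scores=[0.9, 0.3],
+                )
+                results = model.evaluate(evaluator)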
+        """
+        if output_path is not None:
+            os.makedirs(output_path, exist_ok=True)
+        return evaluator(self, output_path)
+
+    def _load_auto_model(
+        self,
+        model_name_or_path: str,
+        token: bool | str | None,
+        cache_folder: str | None,
+        revision: str | None = None,
+        trust_remote_code: bool = False,
+        local_files_only: bool = False,
+        model_kwargs: dict[str, Any] | None = None,
+        tokenizer_kwargs: dict[str, Any] | None = None,
+        config_kwargs: dict[str, Any] | None = None,
+    ) -> list[nn.Module]:
+        """
+        Creates a simple Transformer + Mean Pooling model and returns the modules
+
+        Args:
+            model_name_or_path (str): The name or path of the pre-trained model.
+            token (Optional[Union[bool, str]]): The token to use for the model.
+            cache_folder (Optional[str]): The folder to cache the model.
+            revision (Optional[str], optional): The revision of the model. Defaults to None.
+            trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
+            local_files_only (bool, optional): Whether to use only local files. Defaults to False.
+            model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the model. Defaults to None.
+            tokenizer_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the tokenizer. Defaults to None.
+            config_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the config. Defaults to None.
+
+        Returns:
+            List[nn.Module]: A list containing the transformer model and the pooling model.
+        """
+        logger.warning(
+            f"No sentence-transformers model found with name {model_name_or_path}. Creating a new one with mean pooling."
+        )
+
+        shared_kwargs = {
+            "token": token,
+            "trust_remote_code": trust_remote_code,
+            "revision": revision,
+            "local_files_only": local_files_only,
+        }
+        model_kwargs = shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
+        tokenizer_kwargs = shared_kwargs if tokenizer_kwargs is None else {**shared_kwargs, **tokenizer_kwargs}
+        config_kwargs = shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}
+
+        transformer_model = Transformer(
+            model_name_or_path,
+            cache_dir=cache_folder,
+            model_args=model_kwargs,
+            tokenizer_args=tokenizer_kwargs,
+            config_args=config_kwargs,
+            backend=self.backend,
+        )
+        pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
+        self.model_card_data.set_base_model(model_name_or_path, revision=revision)
+        return [transformer_model, pooling_model]
+
+    def _load_module_class_from_ref(
+        self,
+        class_ref: str,
+        model_name_or_path: str,
+        trust_remote_code: bool,
+        revision: str | None,
+        model_kwargs: dict[str, Any] | None,
+    ) -> nn.Module:
+        # If the class is from sentence_transformers, we can directly import it,
+        # otherwise, we try to import it dynamically, and if that fails, we fall back to the default import
+        if class_ref.startswith("sentence_transformers."):
+            return import_from_string(class_ref)
+
+        if trust_remote_code:
+            code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None
+            try:
+                return get_class_from_dynamic_module(
+                    class_ref,
+                    model_name_or_path,
+                    revision=revision,
+                    code_revision=code_revision,
+                )
+            except OSError:
+                # Ignore the error if the file does not exist, and fall back to the default import
+                pass
+
+        return import_from_string(class_ref)
+
+    def _load_sbert_model(
+        self,
+        model_name_or_path: str,
+        token: bool | str | None,
+        cache_folder: str | None,
+        revision: str | None = None,
+        trust_remote_code: bool = False,
+        local_files_only: bool = False,
+        model_kwargs: dict[str, Any] | None = None,
+        tokenizer_kwargs: dict[str, Any] | None = None,
+        config_kwargs: dict[str, Any] | None = None,
+    ) -> dict[str, nn.Module]:
+        """
+        Loads a full SentenceTransformer model using the modules.json file.
+
+        Args:
+            model_name_or_path (str): The name or path of the pre-trained model.
+            token (Optional[Union[bool, str]]): The token to use for the model.
+            cache_folder (Optional[str]): The folder to cache the model.
+            revision (Optional[str], optional): The revision of the model. Defaults to None.
+            trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
+            local_files_only (bool, optional): Whether to use only local files. Defaults to False.
+            model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the model. Defaults to None.
+            tokenizer_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the tokenizer. Defaults to None.
+            config_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the config. Defaults to None.
+
+        Returns:
+            OrderedDict[str, nn.Module]: An ordered dictionary containing the modules of the model.
+        """
+        # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
+        config_sentence_transformers_json_path = load_file_path(
+            model_name_or_path,
+            "config_sentence_transformers.json",
+            token=token,
+            cache_folder=cache_folder,
+            revision=revision,
+            local_files_only=local_files_only,
+        )
+        if config_sentence_transformers_json_path is not None:
+            with open(config_sentence_transformers_json_path) as fIn:
+                self._model_config = json.load(fIn)
+
+            if (
+                "__version__" in self._model_config
+                and "sentence_transformers" in self._model_config["__version__"]
+                and self._model_config["__version__"]["sentence_transformers"] > __version__
+            ):
+                logger.warning(
+                    "You are trying to use a model that was created with version {}, but your installed version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(
+                        self._model_config["__version__"]["sentence_transformers"], __version__
+                    )
+                )
+
+            # Set score functions & prompts if not already overridden by the __init__ calls
+            if self._similarity_fn_name is None:
+                self.similarity_fn_name = self._model_config.get("similarity_fn_name", None)
+            if not self.prompts:
+                self.prompts = self._model_config.get("prompts", {})
+            if not self.default_prompt_name:
+                self.default_prompt_name = self._model_config.get("default_prompt_name", None)
+
+        # Check if a readme exists
+        model_card_path = load_file_path(
+            model_name_or_path,
+            "README.md",
+            token=token,
+            cache_folder=cache_folder,
+            revision=revision,
+            local_files_only=local_files_only,
+        )
+        if model_card_path is not None:
+            try:
+                with open(model_card_path, encoding="utf8") as fIn:
+                    self._model_card_text = fIn.read()
+            except Exception:
+                pass
+
+        # Load the modules of sentence transformer
+        modules_json_path = load_file_path(
+            model_name_or_path,
+            "modules.json",
+            token=token,
+            cache_folder=cache_folder,
+            revision=revision,
+            local_files_only=local_files_only,
+        )
+        with open(modules_json_path) as fIn:
+            modules_config = json.load(fIn)
+
+        modules = OrderedDict()
+        module_kwargs = OrderedDict()
+        for module_config in modules_config:
+            class_ref = module_config["type"]
+            module_class = self._load_module_class_from_ref(
+                class_ref, model_name_or_path, trust_remote_code, revision, model_kwargs
+            )
+
+            # For Transformer, don't load the full directory, rely on `transformers` instead
+            # But, do load the config file first.
+            if module_config["path"] == "":
+                kwargs = {}
+                for config_name in [
+                    "sentence_bert_config.json",
+                    "sentence_roberta_config.json",
+                    "sentence_distilbert_config.json",
+                    "sentence_camembert_config.json",
+                    "sentence_albert_config.json",
+                    "sentence_xlm-roberta_config.json",
+                    "sentence_xlnet_config.json",
+                ]:
+                    config_path = load_file_path(
+                        model_name_or_path,
+                        config_name,
+                        token=token,
+                        cache_folder=cache_folder,
+                        revision=revision,
+                        local_files_only=local_files_only,
+                    )
+                    if config_path is not None:
+                        with open(config_path) as fIn:
+                            kwargs = json.load(fIn)
+                            # Don't allow configs to set trust_remote_code
+                            if "model_args" in kwargs and "trust_remote_code" in kwargs["model_args"]:
+                                kwargs["model_args"].pop("trust_remote_code")
+                            if "tokenizer_args" in kwargs and "trust_remote_code" in kwargs["tokenizer_args"]:
+                                kwargs["tokenizer_args"].pop("trust_remote_code")
+                            if "config_args" in kwargs and "trust_remote_code" in kwargs["config_args"]:
+                                kwargs["config_args"].pop("trust_remote_code")
+                        break
+
+                hub_kwargs = {
+                    "token": token,
+                    "trust_remote_code": trust_remote_code,
+                    "revision": revision,
+                    "local_files_only": local_files_only,
+                }
+                # 3rd priority: config file
+                if "model_args" not in kwargs:
+                    kwargs["model_args"] = {}
+                if "tokenizer_args" not in kwargs:
+                    kwargs["tokenizer_args"] = {}
+                if "config_args" not in kwargs:
+                    kwargs["config_args"] = {}
+
+                # 2nd priority: hub_kwargs
+                kwargs["model_args"].update(hub_kwargs)
+                kwargs["tokenizer_args"].update(hub_kwargs)
+                kwargs["config_args"].update(hub_kwargs)
+
+                # 1st priority: kwargs passed to SentenceTransformer
+                if model_kwargs:
+                    kwargs["model_args"].update(model_kwargs)
+                if tokenizer_kwargs:
+                    kwargs["tokenizer_args"].update(tokenizer_kwargs)
+                if config_kwargs:
+                    kwargs["config_args"].update(config_kwargs)
+
+                # Try to initialize the module with a lot of kwargs, but only if the module supports them
+                # Otherwise we fall back to the load method
+                try:
+                    module = module_class(model_name_or_path, cache_dir=cache_folder, backend=self.backend, **kwargs)
+                except TypeError:
+                    module = module_class.load(model_name_or_path)
+            else:
+                # Normalize does not require any files to be loaded
+                if module_class == Normalize:
+                    module_path = None
+                else:
+                    module_path = load_dir_path(
+                        model_name_or_path,
+                        module_config["path"],
+                        token=token,
+                        cache_folder=cache_folder,
+                        revision=revision,
+                        local_files_only=local_files_only,
+                    )
+                module = module_class.load(module_path)
+
+            modules[module_config["name"]] = module
+            module_kwargs[module_config["name"]] = module_config.get("kwargs", [])
+
+        if revision is None:
+            path_parts = Path(modules_json_path)
+            if len(path_parts.parts) >= 2:
+                revision_path_part = Path(modules_json_path).parts[-2]
+                if len(revision_path_part) == 40:
+                    revision = revision_path_part
+        self.model_card_data.set_base_model(model_name_or_path, revision=revision)
+        return modules, module_kwargs
+
+    @staticmethod
+    def load(input_path) -> SentenceTransformer:
+        return SentenceTransformer(input_path)
+
+    @property
+    def device(self) -> device:
+        """
+        Get torch.device from module, assuming that the whole module has one device.
+        In case there are no PyTorch parameters, fall back to CPU.
+        """
+        if isinstance(self[0], Transformer):
+            return self[0].auto_model.device
+
+        try:
+            return next(self.parameters()).device
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            try:
+                first_tuple = next(gen)
+                return first_tuple[1].device
+            except StopIteration:
+                return torch.device("cpu")
+
+    @property
+    def tokenizer(self) -> Any:
+        """
+        Property to get the tokenizer that is used by this model
+        """
+        return self._first_module().tokenizer
+
+    @tokenizer.setter
+    def tokenizer(self, value) -> None:
+        """
+        Property to set the tokenizer that should be used by this model
+        """
+        self._first_module().tokenizer = value
+
+    @property
+    def max_seq_length(self) -> int:
+        """
+        Returns the maximal input sequence length for the model. Longer inputs will be truncated.
+
+        Returns:
+            int: The maximal input sequence length.
+
+        Example:
+            ::
+
+                from sentence_transformers import SentenceTransformer
+
+                model = SentenceTransformer("all-mpnet-base-v2")
+                print(model.max_seq_length)
+                # => 384
+        """
+        return self._first_module().max_seq_length
+
+    @max_seq_length.setter
+    def max_seq_length(self, value) -> None:
+        """
+        Property to set the maximal input sequence length for the model. Longer inputs will be truncated.
+        """
+        self._first_module().max_seq_length = value
+
+    @property
+    def _target_device(self) -> torch.device:
+        logger.warning(
+            "`SentenceTransformer._target_device` has been deprecated, please use `SentenceTransformer.device` instead.",
+        )
+        return self.device
+
+    @_target_device.setter
+    def _target_device(self, device: int | str | torch.device | None = None) -> None:
+        self.to(device)
+
+    @property
+    def _no_split_modules(self) -> list[str]:
+        try:
+            return self._first_module()._no_split_modules
+        except AttributeError:
+            return []
+
+    @property
+    def _keys_to_ignore_on_save(self) -> list[str]:
+        try:
+            return self._first_module()._keys_to_ignore_on_save
+        except AttributeError:
+            return []
+
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None) -> None:
+        # Propagate the gradient checkpointing to the transformer model
+        for module in self:
+            if isinstance(module, Transformer):
+                return module.auto_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
+
+

Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.

+

Args

+
+
model_name_or_path : str, optional
+
If it is a filepath on disc, it loads the model from that path. If it is not a path, +it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model +from the Hugging Face Hub with that name.
+
modules : Iterable[nn.Module], optional
+
A list of torch Modules that should be called sequentially, can be used to create custom +SentenceTransformer models from scratch.
+
device : str, optional
+
Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU +can be used.
+
prompts : Dict[str, str], optional
+
A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text. +The prompt text will be prepended before any text to encode. For example: +{"query": "query: ", "passage": "passage: "} or {"clustering": "Identify the main category based on the +titles in "}.
+
default_prompt_name : str, optional
+
The name of the prompt that should be used by default. If not set, +no prompt will be applied.
+
similarity_fn_name : str or SimilarityFunction, optional
+
The name of the similarity function to use. Valid options are "cosine", "dot", +"euclidean", and "manhattan". If not set, it is automatically set to "cosine" if similarity or +similarity_pairwise are called while model.similarity_fn_name is still None.
+
cache_folder : str, optional
+
Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
+
trust_remote_code : bool, optional
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. +This option should only be set to True for repositories you trust and in which you have read the code, as it +will execute code present on the Hub on your local machine.
+
revision : str, optional
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, +for a stored model on Hugging Face.
+
local_files_only : bool, optional
+
Whether or not to only look at local files (i.e., do not try to download the model).
+
token : bool or str, optional
+
Hugging Face authentication token to download private models.
+
use_auth_token : bool or str, optional
+
Deprecated argument. Please use token instead.
+
truncate_dim : int, optional
+
The dimension to truncate sentence embeddings to. None does no truncation. Truncation is +only applicable during inference when :meth:SentenceTransformer.encode() is called.
+
model_kwargs : Dict[str, Any], optional
+
+

Additional model configuration parameters to be passed to the Hugging Face Transformers model. +Particularly useful options are:

+
  • torch_dtype: Override the default torch.dtype and load the model under a specific dtype. The options are:
    1. torch.float16, torch.bfloat16 or torch.float: load in the specified dtype, ignoring the model's config.torch_dtype if one exists. If not specified, the model is loaded in torch.float (fp32).
    2. "auto": a torch_dtype entry in the model's config.json is used if present; otherwise the dtype of the first floating-point weight in the checkpoint is used. This loads the model in the dtype it was saved in at the end of training, which is not a reliable indicator of how it was trained, since a model can be trained in a half-precision dtype but saved in fp32.
  • attn_implementation: The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>), or "flash_attention_2" (using Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>). By default, SDPA is used for torch>=2.1.1 if available; otherwise the manual "eager" implementation is the default.
  • provider: If backend is "onnx", this is the provider to use for inference, for example "CPUExecutionProvider", "CUDAExecutionProvider", etc. See <https://onnxruntime.ai/docs/execution-providers/> for all ONNX execution providers.
  • file_name: If backend is "onnx" or "openvino", this is the file name to load, useful for loading optimized or quantized ONNX or OpenVINO models.
  • export: If backend is "onnx" or "openvino", this is a boolean flag specifying whether the model should be exported to the backend. If not specified, the model is exported only if the model repository or directory does not already contain an exported model.

A minimal loading sketch that uses these options follows the constructor example below.

See the PreTrainedModel.from_pretrained +<https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>_ +documentation for more details.

+
+
tokenizer_kwargs : Dict[str, Any], optional
+
Additional tokenizer configuration parameters to be passed to the Hugging Face Transformers tokenizer. +See the AutoTokenizer.from_pretrained +<https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>_ +documentation for more details.
+
config_kwargs : Dict[str, Any], optional
+
Additional model configuration parameters to be passed to the Hugging Face Transformers config. +See the AutoConfig.from_pretrained +<https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>_ +documentation for more details.
+
model_card_data : SentenceTransformerModelCardData, optional
+
A model card data object that contains information about the model. This is used to generate a model card when saving the model. If not set, a default model card data object is created.
+
backend : str
+
The backend to use for inference. Can be one of "torch" (default), "onnx", or "openvino". +See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for benchmarking information +on the different backends.
+
+

Example

+

::

+
from sentence_transformers import SentenceTransformer
+
+# Load a pre-trained SentenceTransformer model
+model = SentenceTransformer('all-mpnet-base-v2')
+
+# Encode some texts
+sentences = [
+    "The weather is lovely today.",
+    "It's so sunny outside!",
+    "He drove to the stadium.",
+]
+embeddings = model.encode(sentences)
+print(embeddings.shape)
+# (3, 768)
+
+# Get the similarity scores between all sentences
+similarities = model.similarity(embeddings, embeddings)
+print(similarities)
+# tensor([[1.0000, 0.6817, 0.0492],
+#         [0.6817, 1.0000, 0.0421],
+#         [0.0492, 0.0421, 1.0000]])
+
+
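The model_kwargs options listed above can be combined at construction time. The following is a minimal sketch, not part of the library documentation: it assumes a CUDA-capable machine and that the (illustrative) checkpoint supports the SDPA attention implementation.

import torch
from sentence_transformers import SentenceTransformer

# Hedged sketch: load in half precision with the SDPA attention implementation
model = SentenceTransformer(
    "all-mpnet-base-v2",                       # illustrative model name
    device="cuda",                             # assumes a GPU is available
    model_kwargs={
        "torch_dtype": torch.float16,          # override the default dtype
        "attn_implementation": "sdpa",         # scaled_dot_product_attention backend
    },
)
print(model.encode(["A quick smoke test."]).shape)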

Initialize internal Module state, shared by both nn.Module and ScriptModule.

+

Ancestors

+
    +
  • torch.nn.modules.container.Sequential
  • torch.nn.modules.module.Module
  • sentence_transformers.fit_mixin.FitMixin
  • sentence_transformers.peft_mixin.PeftAdapterMixin
+

Static methods

+
+
+def load(input_path) ‑> sentence_transformers.SentenceTransformer.SentenceTransformer +
+
+
+ +Expand source code + +
@staticmethod
+def load(input_path) -> SentenceTransformer:
+    return SentenceTransformer(input_path)
+
+
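load simply delegates to the constructor, so both calls below are equivalent; a minimal sketch with an illustrative path:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer.load("./my-model-dir")   # static helper
model = SentenceTransformer("./my-model-dir")        # equivalent constructor call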
+
+
+def stop_multi_process_pool(pool: "dict[Literal['input', 'output', 'processes'], Any]") ‑> None +
+
+
+ +Expand source code + +
@staticmethod
+def stop_multi_process_pool(pool: dict[Literal["input", "output", "processes"], Any]) -> None:
+    """
+    Stops all processes started with start_multi_process_pool.
+
+    Args:
+        pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.
+
+    Returns:
+        None
+    """
+    for p in pool["processes"]:
+        p.terminate()
+
+    for p in pool["processes"]:
+        p.join()
+        p.close()
+
+    pool["input"].close()
+    pool["output"].close()
+
+

Stops all processes started with start_multi_process_pool.

+

Args

+
+
pool : Dict[str, object]
+
A dictionary containing the input queue, output queue, and process list.
+
+

Returns

+

None

+
+
+

Instance variables

+
+
prop device : device
+
+
+ +Expand source code + +
@property
+def device(self) -> device:
+    """
+    Get torch.device from module, assuming that the whole module has one device.
+    In case there are no PyTorch parameters, fall back to CPU.
+    """
+    if isinstance(self[0], Transformer):
+        return self[0].auto_model.device
+
+    try:
+        return next(self.parameters()).device
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = self._named_members(get_members_fn=find_tensor_attributes)
+        try:
+            first_tuple = next(gen)
+            return first_tuple[1].device
+        except StopIteration:
+            return torch.device("cpu")
+
+

Get torch.device from module, assuming that the whole module has one device. +In case there are no PyTorch parameters, fall back to CPU.

+
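A small, hedged illustration of the property above (the model name is illustrative and the second step assumes a CUDA device is present):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
print(model.device)      # e.g. device(type='cpu') right after loading on a CPU-only machine

model.to("cuda")         # moving the module changes what the property reports
print(model.device)      # device(type='cuda', index=0)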
+
prop max_seq_length : int
+
+
+ +Expand source code + +
@property
+def max_seq_length(self) -> int:
+    """
+    Returns the maximal input sequence length for the model. Longer inputs will be truncated.
+
+    Returns:
+        int: The maximal input sequence length.
+
+    Example:
+        ::
+
+            from sentence_transformers import SentenceTransformer
+
+            model = SentenceTransformer("all-mpnet-base-v2")
+            print(model.max_seq_length)
+            # => 384
+    """
+    return self._first_module().max_seq_length
+
+

Returns the maximal input sequence length for the model. Longer inputs will be truncated.

+

Returns

+
+
int
+
The maximal input sequence length.
+
+

Example

+

::

+
from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("all-mpnet-base-v2")
+print(model.max_seq_length)
+# => 384
+
+
+
prop similarity : Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]
+
+
+ +Expand source code + +
@property
+def similarity(self) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
+    """
+    Compute the similarity between two collections of embeddings. The output will be a matrix with the similarity
+    scores between all embeddings from the first parameter and all embeddings from the second parameter. This
+    differs from `similarity_pairwise` which computes the similarity between each pair of embeddings.
+
+    Args:
+        embeddings1 (Union[Tensor, ndarray]): [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+        embeddings2 (Union[Tensor, ndarray]): [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
+    Returns:
+        Tensor: A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.
+
+    Example:
+        ::
+
+            >>> model = SentenceTransformer("all-mpnet-base-v2")
+            >>> sentences = [
+            ...     "The weather is so nice!",
+            ...     "It's so sunny outside.",
+            ...     "He's driving to the movie theater.",
+            ...     "She's going to the cinema.",
+            ... ]
+            >>> embeddings = model.encode(sentences, normalize_embeddings=True)
+            >>> model.similarity(embeddings, embeddings)
+            tensor([[1.0000, 0.7235, 0.0290, 0.1309],
+                    [0.7235, 1.0000, 0.0613, 0.1129],
+                    [0.0290, 0.0613, 1.0000, 0.5027],
+                    [0.1309, 0.1129, 0.5027, 1.0000]])
+            >>> model.similarity_fn_name
+            "cosine"
+            >>> model.similarity_fn_name = "euclidean"
+            >>> model.similarity(embeddings, embeddings)
+            tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
+                    [-0.7437, -0.0000, -1.3702, -1.3320],
+                    [-1.3935, -1.3702, -0.0000, -0.9973],
+                    [-1.3184, -1.3320, -0.9973, -0.0000]])
+    """
+    if self.similarity_fn_name is None:
+        self.similarity_fn_name = SimilarityFunction.COSINE
+    return self._similarity
+
+

Compute the similarity between two collections of embeddings. The output will be a matrix with the similarity +scores between all embeddings from the first parameter and all embeddings from the second parameter. This +differs from similarity_pairwise which computes the similarity between each pair of embeddings.

+

Args

+
+
embeddings1 : Union[Tensor, ndarray]
+
[num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
embeddings2 : Union[Tensor, ndarray]
+
[num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
+

Returns

+
+
Tensor
+
A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.
+
+

Example

+

::

+
>>> model = SentenceTransformer("all-mpnet-base-v2")
+>>> sentences = [
+...     "The weather is so nice!",
+...     "It's so sunny outside.",
+...     "He's driving to the movie theater.",
+...     "She's going to the cinema.",
+... ]
+>>> embeddings = model.encode(sentences, normalize_embeddings=True)
+>>> model.similarity(embeddings, embeddings)
+tensor([[1.0000, 0.7235, 0.0290, 0.1309],
+        [0.7235, 1.0000, 0.0613, 0.1129],
+        [0.0290, 0.0613, 1.0000, 0.5027],
+        [0.1309, 0.1129, 0.5027, 1.0000]])
+>>> model.similarity_fn_name
+"cosine"
+>>> model.similarity_fn_name = "euclidean"
+>>> model.similarity(embeddings, embeddings)
+tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
+        [-0.7437, -0.0000, -1.3702, -1.3320],
+        [-1.3935, -1.3702, -0.0000, -0.9973],
+        [-1.3184, -1.3320, -0.9973, -0.0000]])
+
+
+
prop similarity_fn_name : Literal['cosine', 'dot', 'euclidean', 'manhattan']
+
+
+ +Expand source code + +
@property
+def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]:
+    """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
+
+    Returns:
+        Optional[str]: The name of the similarity function. Can be None if not set, in which case it will
+            default to "cosine" when first called.
+
+    Example:
+        >>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
+        >>> model.similarity_fn_name
+        'dot'
+    """
+    if self._similarity_fn_name is None:
+        self.similarity_fn_name = SimilarityFunction.COSINE
+    return self._similarity_fn_name
+
+

Return the name of the similarity function used by :meth:SentenceTransformer.similarity and :meth:SentenceTransformer.similarity_pairwise.

+

Returns

+
+
Optional[str]
+
The name of the similarity function. Can be None if not set, in which case it will +default to "cosine" when first called.
+
+

Example

+
>>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
+>>> model.similarity_fn_name
+'dot'
+
+
+
prop similarity_pairwise : Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]
+
+
+ +Expand source code + +
@property
+def similarity_pairwise(self) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
+    """
+    Compute the similarity between two collections of embeddings. The output will be a vector with the similarity
+    scores between each pair of embeddings.
+
+    Args:
+        embeddings1 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+        embeddings2 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
+    Returns:
+        Tensor: A [num_embeddings]-shaped torch tensor with pairwise similarity scores.
+
+    Example:
+        ::
+
+            >>> model = SentenceTransformer("all-mpnet-base-v2")
+            >>> sentences = [
+            ...     "The weather is so nice!",
+            ...     "It's so sunny outside.",
+            ...     "He's driving to the movie theater.",
+            ...     "She's going to the cinema.",
+            ... ]
+            >>> embeddings = model.encode(sentences, normalize_embeddings=True)
+            >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+            tensor([0.7235, 0.5027])
+            >>> model.similarity_fn_name
+            "cosine"
+            >>> model.similarity_fn_name = "euclidean"
+            >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+            tensor([-0.7437, -0.9973])
+    """
+    if self.similarity_fn_name is None:
+        self.similarity_fn_name = SimilarityFunction.COSINE
+    return self._similarity_pairwise
+
+

Compute the similarity between two collections of embeddings. The output will be a vector with the similarity +scores between each pair of embeddings.

+

Args

+
+
embeddings1 : Union[Tensor, ndarray]
+
[num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
embeddings2 : Union[Tensor, ndarray]
+
[num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
+
+

Returns

+
+
Tensor
+
A [num_embeddings]-shaped torch tensor with pairwise similarity scores.
+
+

Example

+

::

+
>>> model = SentenceTransformer("all-mpnet-base-v2")
+>>> sentences = [
+...     "The weather is so nice!",
+...     "It's so sunny outside.",
+...     "He's driving to the movie theater.",
+...     "She's going to the cinema.",
+... ]
+>>> embeddings = model.encode(sentences, normalize_embeddings=True)
+>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+tensor([0.7235, 0.5027])
+>>> model.similarity_fn_name
+"cosine"
+>>> model.similarity_fn_name = "euclidean"
+>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
+tensor([-0.7437, -0.9973])
+
+
+
prop tokenizer : Any
+
+
+ +Expand source code + +
@property
+def tokenizer(self) -> Any:
+    """
+    Property to get the tokenizer that is used by this model
+    """
+    return self._first_module().tokenizer
+
+

Property to get the tokenizer that is used by this model

+
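A brief, hedged sketch of using the property directly; the model name is illustrative and the exact output keys depend on the underlying Hugging Face tokenizer:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
encoded = model.tokenizer("The weather is lovely today.")
print(encoded["input_ids"][:8])   # raw wordpiece ids from the underlying tokenizer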
+
+

Methods

+
+
+def encode(self,
sentences: str | list[str],
prompt_name: str | None = None,
prompt: str | None = None,
batch_size: int = 32,
show_progress_bar: bool | None = None,
output_value: "Literal['sentence_embedding', 'token_embeddings'] | None" = 'sentence_embedding',
precision: "Literal['float32', 'int8', 'uint8', 'binary', 'ubinary']" = 'float32',
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
device: str = None,
normalize_embeddings: bool = False,
**kwargs) ‑> list[torch.Tensor] | numpy.ndarray | torch.Tensor
+
+
+
+ +Expand source code + +
def encode(
+    self,
+    sentences: str | list[str],
+    prompt_name: str | None = None,
+    prompt: str | None = None,
+    batch_size: int = 32,
+    show_progress_bar: bool | None = None,
+    output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding",
+    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+    convert_to_numpy: bool = True,
+    convert_to_tensor: bool = False,
+    device: str = None,
+    normalize_embeddings: bool = False,
+    **kwargs,
+) -> list[Tensor] | np.ndarray | Tensor:
+    """
+    Computes sentence embeddings.
+
+    Args:
+        sentences (Union[str, List[str]]): The sentences to embed.
+        prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
+            which is either set in the constructor or loaded from the model configuration. For example if
+            ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
+            is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
+            is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
+        prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
+            sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+            because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
+        batch_size (int, optional): The batch size used for the computation. Defaults to 32.
+        show_progress_bar (bool, optional): Whether to output a progress bar when encoding sentences. Defaults to None.
+        output_value (Optional[Literal["sentence_embedding", "token_embeddings"]], optional): The type of embeddings to return:
+            "sentence_embedding" to get sentence embeddings, "token_embeddings" to get wordpiece token embeddings, and `None`,
+            to get all output values. Defaults to "sentence_embedding".
+        precision (Literal["float32", "int8", "uint8", "binary", "ubinary"], optional): The precision to use for the embeddings.
+            Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions are quantized embeddings.
+            Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy. They are useful for
+            reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to "float32".
+        convert_to_numpy (bool, optional): Whether the output should be a list of numpy vectors. If False, it is a list of PyTorch tensors.
+            Defaults to True.
+        convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
+            Defaults to False.
+        device (str, optional): Which :class:`torch.device` to use for the computation. Defaults to None.
+        normalize_embeddings (bool, optional): Whether to normalize returned vectors to have length 1. In that case,
+            the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
+
+    Returns:
+        Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
+        If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If ``convert_to_tensor``,
+        a torch Tensor is returned instead. If ``self.truncate_dim <= output_dimension`` then output_dimension is ``self.truncate_dim``.
+
+    Example:
+        ::
+
+            from sentence_transformers import SentenceTransformer
+
+            # Load a pre-trained SentenceTransformer model
+            model = SentenceTransformer('all-mpnet-base-v2')
+
+            # Encode some texts
+            sentences = [
+                "The weather is lovely today.",
+                "It's so sunny outside!",
+                "He drove to the stadium.",
+            ]
+            embeddings = model.encode(sentences)
+            print(embeddings.shape)
+            # (3, 768)
+    """
+    if self.device.type == "hpu" and not self.is_hpu_graph_enabled:
+        import habana_frameworks.torch as ht
+
+        ht.hpu.wrap_in_hpu_graph(self, disable_tensor_cache=True)
+        self.is_hpu_graph_enabled = True
+
+    self.eval()
+    if show_progress_bar is None:
+        show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
+
+    if convert_to_tensor:
+        convert_to_numpy = False
+
+    if output_value != "sentence_embedding":
+        convert_to_tensor = False
+        convert_to_numpy = False
+
+    input_was_string = False
+    if isinstance(sentences, str) or not hasattr(
+        sentences, "__len__"
+    ):  # Cast an individual sentence to a list with length 1
+        sentences = [sentences]
+        input_was_string = True
+
+    if prompt is None:
+        if prompt_name is not None:
+            try:
+                prompt = self.prompts[prompt_name]
+            except KeyError:
+                raise ValueError(
+                    f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.prompts.keys())!r}."
+                )
+        elif self.default_prompt_name is not None:
+            prompt = self.prompts.get(self.default_prompt_name, None)
+    else:
+        if prompt_name is not None:
+            logger.warning(
+                "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
+                "Ignoring the `prompt_name` in favor of `prompt`."
+            )
+
+    extra_features = {}
+    if prompt is not None:
+        sentences = [prompt + sentence for sentence in sentences]
+
+        # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
+        # Tracking the prompt length allows us to remove the prompt during pooling
+        tokenized_prompt = self.tokenize([prompt])
+        if "input_ids" in tokenized_prompt:
+            extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
+
+    if device is None:
+        device = self.device
+
+    self.to(device)
+
+    all_embeddings = []
+    length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
+    sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+    for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
+        sentences_batch = sentences_sorted[start_index : start_index + batch_size]
+        features = self.tokenize(sentences_batch)
+        if self.device.type == "hpu":
+            if "input_ids" in features:
+                curr_tokenize_len = features["input_ids"].shape
+                additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
+                features["input_ids"] = torch.cat(
+                    (
+                        features["input_ids"],
+                        torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                    ),
+                    -1,
+                )
+                features["attention_mask"] = torch.cat(
+                    (
+                        features["attention_mask"],
+                        torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                    ),
+                    -1,
+                )
+                if "token_type_ids" in features:
+                    features["token_type_ids"] = torch.cat(
+                        (
+                            features["token_type_ids"],
+                            torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                        ),
+                        -1,
+                    )
+
+        features = batch_to_device(features, device)
+        features.update(extra_features)
+
+        with torch.no_grad():
+            out_features = self.forward(features, **kwargs)
+            if self.device.type == "hpu":
+                out_features = copy.deepcopy(out_features)
+
+            out_features["sentence_embedding"] = truncate_embeddings(
+                out_features["sentence_embedding"], self.truncate_dim
+            )
+
+            if output_value == "token_embeddings":
+                embeddings = []
+                for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
+                    last_mask_id = len(attention) - 1
+                    while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                        last_mask_id -= 1
+
+                    embeddings.append(token_emb[0 : last_mask_id + 1])
+            elif output_value is None:  # Return all outputs
+                embeddings = []
+                for sent_idx in range(len(out_features["sentence_embedding"])):
+                    row = {name: out_features[name][sent_idx] for name in out_features}
+                    embeddings.append(row)
+            else:  # Sentence embeddings
+                embeddings = out_features[output_value]
+                embeddings = embeddings.detach()
+                if normalize_embeddings:
+                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+                # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
+                if convert_to_numpy:
+                    embeddings = embeddings.cpu()
+
+            all_embeddings.extend(embeddings)
+
+    all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
+
+    if precision and precision != "float32":
+        all_embeddings = quantize_embeddings(all_embeddings, precision=precision)
+
+    if convert_to_tensor:
+        if len(all_embeddings):
+            if isinstance(all_embeddings, np.ndarray):
+                all_embeddings = torch.from_numpy(all_embeddings)
+            else:
+                all_embeddings = torch.stack(all_embeddings)
+        else:
+            all_embeddings = torch.Tensor()
+    elif convert_to_numpy:
+        if not isinstance(all_embeddings, np.ndarray):
+            if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
+                all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
+            else:
+                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+    elif isinstance(all_embeddings, np.ndarray):
+        all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]
+
+    if input_was_string:
+        all_embeddings = all_embeddings[0]
+
+    return all_embeddings
+
+

Computes sentence embeddings.

+

Args

+
+
sentences : Union[str, List[str]]
+
The sentences to embed.
+
prompt_name : Optional[str], optional
+
The name of the prompt to use for encoding. Must be a key in the prompts dictionary, +which is either set in the constructor or loaded from the model configuration. For example if +prompt_name is "query" and the prompts is {"query": "query: ", …}, then the sentence "What +is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence +is appended to the prompt. If prompt is also set, this argument is ignored. Defaults to None.
+
prompt : Optional[str], optional
+
The prompt to use for encoding. For example, if the prompt is "query: ", then the +sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" +because the sentence is appended to the prompt. If prompt is set, prompt_name is ignored. Defaults to None.
+
batch_size : int, optional
+
The batch size used for the computation. Defaults to 32.
+
show_progress_bar : bool, optional
+
Whether to output a progress bar when encoding sentences. Defaults to None.
+
output_value : Optional[Literal["sentence_embedding", "token_embeddings"]], optional
+
The type of embeddings to return: "sentence_embedding" to get sentence embeddings, "token_embeddings" to get wordpiece token embeddings, and None to get all output values. Defaults to "sentence_embedding".
+
precision : Literal["float32", "int8", "uint8", "binary", "ubinary"], optional
+
The precision to use for the embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may have lower accuracy. They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to "float32". A short sketch exercising this argument follows the example below.
+
convert_to_numpy : bool, optional
+
Whether the output should be a list of numpy vectors. If False, it is a list of PyTorch tensors. +Defaults to True.
+
convert_to_tensor : bool, optional
+
Whether the output should be one large tensor. Overwrites convert_to_numpy. +Defaults to False.
+
device : str, optional
+
Which :class:torch.device to use for the computation. Defaults to None.
+
normalize_embeddings : bool, optional
+
Whether to normalize returned vectors to have length 1. In that case, +the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
+
+

Returns

+
+
Union[List[Tensor], ndarray, Tensor]
+
By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
+
+

If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If convert_to_tensor, +a torch Tensor is returned instead. If self.truncate_dim <= output_dimension then output_dimension is self.truncate_dim.

+

Example

+

::

+
from sentence_transformers import SentenceTransformer
+
+# Load a pre-trained SentenceTransformer model
+model = SentenceTransformer('all-mpnet-base-v2')
+
+# Encode some texts
+sentences = [
+    "The weather is lovely today.",
+    "It's so sunny outside!",
+    "He drove to the stadium.",
+]
+embeddings = model.encode(sentences)
+print(embeddings.shape)
+# (3, 768)
+
+
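To complement the example above, the precision and normalize_embeddings arguments can be exercised as sketched below; the model name is illustrative and int8 quantization is calibrated from the batch itself, so exact values will vary:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
sentences = ["The weather is lovely today.", "He drove to the stadium."]

# Unit-length float32 embeddings: dot product then equals cosine similarity
normalized = model.encode(sentences, normalize_embeddings=True)

# Quantized int8 embeddings: smaller and faster, potentially less accurate
quantized = model.encode(sentences, precision="int8")

print(normalized.dtype, quantized.dtype)   # float32 int8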
+
+def encode_multi_process(self,
sentences: list[str],
pool: "dict[Literal['input', 'output', 'processes'], Any]",
prompt_name: str | None = None,
prompt: str | None = None,
batch_size: int = 32,
chunk_size: int = None,
show_progress_bar: bool | None = None,
precision: "Literal['float32', 'int8', 'uint8', 'binary', 'ubinary']" = 'float32',
normalize_embeddings: bool = False) ‑> numpy.ndarray
+
+
+
+ +Expand source code + +
def encode_multi_process(
+    self,
+    sentences: list[str],
+    pool: dict[Literal["input", "output", "processes"], Any],
+    prompt_name: str | None = None,
+    prompt: str | None = None,
+    batch_size: int = 32,
+    chunk_size: int = None,
+    show_progress_bar: bool | None = None,
+    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+    normalize_embeddings: bool = False,
+) -> np.ndarray:
+    """
+    Encodes a list of sentences using multiple processes and GPUs via
+    :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.
+    The sentences are chunked into smaller packages and sent to individual processes, which encode them on different
+    GPUs or CPUs. This method is only suitable for encoding large sets of sentences.
+
+    Args:
+        sentences (List[str]): List of sentences to encode.
+        pool (Dict[Literal["input", "output", "processes"], Any]): A pool of workers started with
+            :meth:`SentenceTransformer.start_multi_process_pool <sentence_transformers.SentenceTransformer.start_multi_process_pool>`.
+        prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
+            which is either set in the constructor or loaded from the model configuration. For example if
+            ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
+            is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
+            is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
+        prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
+            sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
+            because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
+        batch_size (int): Encode sentences with batch size. (default: 32)
+        chunk_size (int): Sentences are chunked and sent to the individual processes. If None, it determines a
+            sensible size. Defaults to None.
+        show_progress_bar (bool, optional): Whether to output a progress bar when encoding sentences. Defaults to None.
+        precision (Literal["float32", "int8", "uint8", "binary", "ubinary"]): The precision to use for the
+            embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions
+            are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may
+            have lower accuracy. They are useful for reducing the size of the embeddings of a corpus for
+            semantic search, among other tasks. Defaults to "float32".
+        normalize_embeddings (bool): Whether to normalize returned vectors to have length 1. In that case,
+            the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
+
+    Returns:
+        np.ndarray: A 2D numpy array with shape [num_inputs, output_dimension].
+
+    Example:
+        ::
+
+            from sentence_transformers import SentenceTransformer
+
+            def main():
+                model = SentenceTransformer("all-mpnet-base-v2")
+                sentences = ["The weather is so nice!", "It's so sunny outside.", "He's driving to the movie theater.", "She's going to the cinema."] * 1000
+
+                pool = model.start_multi_process_pool()
+                embeddings = model.encode_multi_process(sentences, pool)
+                model.stop_multi_process_pool(pool)
+
+                print(embeddings.shape)
+                # => (4000, 768)
+
+            if __name__ == "__main__":
+                main()
+    """
+
+    if chunk_size is None:
+        chunk_size = min(math.ceil(len(sentences) / len(pool["processes"]) / 10), 5000)
+
+    if show_progress_bar is None:
+        show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
+
+    logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")
+
+    input_queue = pool["input"]
+    last_chunk_id = 0
+    chunk = []
+
+    for sentence in sentences:
+        chunk.append(sentence)
+        if len(chunk) >= chunk_size:
+            input_queue.put(
+                [last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings]
+            )
+            last_chunk_id += 1
+            chunk = []
+
+    if len(chunk) > 0:
+        input_queue.put([last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings])
+        last_chunk_id += 1
+
+    output_queue = pool["output"]
+    results_list = sorted(
+        [output_queue.get() for _ in trange(last_chunk_id, desc="Chunks", disable=not show_progress_bar)],
+        key=lambda x: x[0],
+    )
+    embeddings = np.concatenate([result[1] for result in results_list])
+    return embeddings
+
+

Encodes a list of sentences using multiple processes and GPUs via +:meth:SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>. +The sentences are chunked into smaller packages and sent to individual processes, which encode them on different +GPUs or CPUs. This method is only suitable for encoding large sets of sentences.

+

Args

+
+
sentences : List[str]
+
List of sentences to encode.
+
pool : Dict[Literal["input", "output", "processes"], Any]
+
A pool of workers started with :meth:SentenceTransformer.start_multi_process_pool <sentence_transformers.SentenceTransformer.start_multi_process_pool>.
+
prompt_name : Optional[str], optional
+
The name of the prompt to use for encoding. Must be a key in the prompts dictionary, +which is either set in the constructor or loaded from the model configuration. For example if +prompt_name is "query" and the prompts is {"query": "query: ", …}, then the sentence "What +is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence +is appended to the prompt. If prompt is also set, this argument is ignored. Defaults to None.
+
prompt : Optional[str], optional
+
The prompt to use for encoding. For example, if the prompt is "query: ", then the +sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" +because the sentence is appended to the prompt. If prompt is set, prompt_name is ignored. Defaults to None.
+
batch_size : int
+
Encode sentences with batch size. (default: 32)
+
chunk_size : int
+
Sentences are chunked and sent to the individual processes. If None, it determines a +sensible size. Defaults to None.
+
show_progress_bar : bool, optional
+
Whether to output a progress bar when encoding sentences. Defaults to None.
+
precision : Literal["float32", "int8", "uint8", "binary", "ubinary"]
+
The precision to use for the embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may have lower accuracy. They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to "float32".
+
normalize_embeddings : bool
+
Whether to normalize returned vectors to have length 1. In that case, +the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
+
+

Returns

+
+
np.ndarray
+
A 2D numpy array with shape [num_inputs, output_dimension].
+
+

Example

+

::

+
from sentence_transformers import SentenceTransformer
+
+def main():
+    model = SentenceTransformer("all-mpnet-base-v2")
+    sentences = ["The weather is so nice!", "It's so sunny outside.", "He's driving to the movie theater.", "She's going to the cinema."] * 1000
+
+    pool = model.start_multi_process_pool()
+    embeddings = model.encode_multi_process(sentences, pool)
+    model.stop_multi_process_pool(pool)
+
+    print(embeddings.shape)
+    # => (4000, 768)
+
+if __name__ == "__main__":
+    main()
+
+
+
+def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None) ‑> dict[str, float] | float +
+
+
+ +Expand source code + +
def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None) -> dict[str, float] | float:
+    """
+    Evaluate the model based on an evaluator
+
+    Args:
+        evaluator (SentenceEvaluator): The evaluator used to evaluate the model.
+        output_path (str, optional): The path where the evaluator can write the results. Defaults to None.
+
+    Returns:
+        The evaluation results.
+    """
+    if output_path is not None:
+        os.makedirs(output_path, exist_ok=True)
+    return evaluator(self, output_path)
+
+

Evaluate the model based on an evaluator

+

Args

+
+
evaluator : SentenceEvaluator
+
The evaluator used to evaluate the model.
+
output_path : str, optional
+
The path where the evaluator can write the results. Defaults to None.
+
+

Returns

+

The evaluation results.

+
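A hedged sketch of calling evaluate with one of the library's built-in evaluators (EmbeddingSimilarityEvaluator); the sentence pairs and gold scores are made up for illustration:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

model = SentenceTransformer("all-mpnet-base-v2")

evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["The weather is lovely today.", "He drove to the stadium."],
    sentences2=["It's so sunny outside!", "She walked to the cinema."],
    scores=[0.9, 0.3],                     # gold similarities in [0, 1]
)

results = model.evaluate(evaluator, output_path="eval_results")
print(results)                             # correlation metrics, also written to eval_results/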
+
+def forward(self,
input: dict[str, Tensor],
**kwargs) ‑> dict[str, torch.Tensor]
+
+
+
+ +Expand source code + +
def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
+    if self.module_kwargs is None:
+        return super().forward(input)
+
+    for module_name, module in self.named_children():
+        module_kwarg_keys = self.module_kwargs.get(module_name, [])
+        module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
+        input = module(input, **module_kwargs)
+    return input
+
+

Define the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+
+def get_backend(self) ‑> Literal['torch', 'onnx', 'openvino'] +
+
+
+ +Expand source code + +
def get_backend(self) -> Literal["torch", "onnx", "openvino"]:
+    """Return the backend used for inference, which can be one of "torch", "onnx", or "openvino".
+
+    Returns:
+        str: The backend used for inference.
+    """
+    return self.backend
+
+

Return the backend used for inference, which can be one of "torch", "onnx", or "openvino".

+

Returns

+
+
str
+
The backend used for inference.
+
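A small sketch, assuming the optional ONNX extras are installed (the model name is illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
print(model.get_backend())         # "torch" (the default backend)

onnx_model = SentenceTransformer("all-mpnet-base-v2", backend="onnx")
print(onnx_model.get_backend())    # "onnx"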
+
+
+def get_max_seq_length(self) ‑> int | None +
+
+
+ +Expand source code + +
def get_max_seq_length(self) -> int | None:
+    """
+    Returns the maximal sequence length that the model accepts. Longer inputs will be truncated.
+
+    Returns:
+        Optional[int]: The maximal sequence length that the model accepts, or None if it is not defined.
+    """
+    if hasattr(self._first_module(), "max_seq_length"):
+        return self._first_module().max_seq_length
+
+    return None
+
+

Returns the maximal sequence length that the model accepts. Longer inputs will be truncated.

+

Returns

+
+
Optional[int]
+
The maximal sequence length that the model accepts, or None if it is not defined.
+
+
+
+def get_sentence_embedding_dimension(self) ‑> int | None +
+
+
+ +Expand source code + +
def get_sentence_embedding_dimension(self) -> int | None:
+    """
+    Returns the number of dimensions in the output of :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.
+
+    Returns:
+        Optional[int]: The number of dimensions in the output of `encode`. If it's not known, it's `None`.
+    """
+    output_dim = None
+    for mod in reversed(self._modules.values()):
+        sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None)
+        if callable(sent_embedding_dim_method):
+            output_dim = sent_embedding_dim_method()
+            break
+    if self.truncate_dim is not None:
+        # The user requested truncation. If they set it to a dim greater than output_dim,
+        # no truncation will actually happen. So return output_dim instead of self.truncate_dim
+        return min(output_dim or np.inf, self.truncate_dim)
+    return output_dim
+
+

Returns the number of dimensions in the output of :meth:SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>.

+

Returns

+
+
Optional[int]
+
The number of dimensions in the output of encode. If it's not known, it's None.
+
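A hedged sketch of how truncate_dim interacts with the reported dimension, as described above; the model name is illustrative and its native output dimension is assumed to be 768:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
print(model.get_sentence_embedding_dimension())      # 768 for this model

truncated = SentenceTransformer("all-mpnet-base-v2", truncate_dim=256)
print(truncated.get_sentence_embedding_dimension())  # 256, capped by truncate_dim
print(truncated.encode(["hello"]).shape)             # (1, 256)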
+
+
+def get_sentence_features(self, *features) ‑> dict[typing.Literal['sentence_embedding'], torch.Tensor] +
+
+
+ +Expand source code + +
def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], Tensor]:
+    return self._first_module().get_sentence_features(*features)
+
+
+
+
+def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None) ‑> None +
+
+
+ +Expand source code + +
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None) -> None:
+    # Propagate the gradient checkpointing to the transformer model
+    for module in self:
+        if isinstance(module, Transformer):
+            return module.auto_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
+
+
+
+
+def push_to_hub(self,
repo_id: str,
token: str | None = None,
private: bool | None = None,
safe_serialization: bool = True,
commit_message: str | None = None,
local_model_path: str | None = None,
exist_ok: bool = False,
replace_model_card: bool = False,
train_datasets: list[str] | None = None,
revision: str | None = None,
create_pr: bool = False) ‑> str
+
+
+
+ +Expand source code + +
    def push_to_hub(
+        self,
+        repo_id: str,
+        token: str | None = None,
+        private: bool | None = None,
+        safe_serialization: bool = True,
+        commit_message: str | None = None,
+        local_model_path: str | None = None,
+        exist_ok: bool = False,
+        replace_model_card: bool = False,
+        train_datasets: list[str] | None = None,
+        revision: str | None = None,
+        create_pr: bool = False,
+    ) -> str:
+        """
+        Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.
+
+        Args:
+            repo_id (str): Repository name for your model in the Hub, including the user or organization.
+            token (str, optional): An authentication token (See https://huggingface.co/settings/token)
+            private (bool, optional): Set to true, for hosting a private model
+            safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way
+            commit_message (str, optional): Message to commit while pushing.
+            local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
+            exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
+            replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card
+            train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
+            revision (str, optional): Branch to push the uploaded files to
+            create_pr (bool, optional): If True, create a pull request instead of pushing directly to the main branch
+
+        Returns:
+            str: The url of the commit of your model in the repository on the Hugging Face Hub.
+        """
+        api = HfApi(token=token)
+        repo_url = api.create_repo(
+            repo_id=repo_id,
+            private=private,
+            repo_type=None,
+            exist_ok=exist_ok or create_pr,
+        )
+        repo_id = repo_url.repo_id  # Update the repo_id in case the old repo_id didn't contain a user or organization
+        self.model_card_data.set_model_id(repo_id)
+        if revision is not None:
+            api.create_branch(repo_id=repo_id, branch=revision, exist_ok=True)
+
+        if commit_message is None:
+            backend = self.get_backend()
+            if backend == "torch":
+                commit_message = "Add new SentenceTransformer model"
+            else:
+                commit_message = f"Add new SentenceTransformer model with an {backend} backend"
+
+        commit_description = ""
+        if create_pr:
+            commit_description = f"""\
+Hello!
+
+*This pull request has been automatically generated from the [`push_to_hub`](https://sbert.net/docs/package_reference/sentence_transformer/SentenceTransformer.html#sentence_transformers.SentenceTransformer.push_to_hub) method from the Sentence Transformers library.*
+
+## Full Model Architecture:
+```
+{self}
+```
+
+## Tip:
+Consider testing this pull request before merging by loading the model from this PR with the `revision` argument:
+```python
+from sentence_transformers import SentenceTransformer
+
+# TODO: Fill in the PR number
+pr_number = 2
+model = SentenceTransformer(
+    "{repo_id}",
+    revision=f"refs/pr/{{pr_number}}",
+    backend="{self.get_backend()}",
+)
+
+# Verify that everything works as expected
+embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."])
+print(embeddings.shape)
+
+similarities = model.similarity(embeddings, embeddings)
+print(similarities)
+```
+"""
+
+        if local_model_path:
+            folder_url = api.upload_folder(
+                repo_id=repo_id,
+                folder_path=local_model_path,
+                commit_message=commit_message,
+                commit_description=commit_description,
+                revision=revision,
+                create_pr=create_pr,
+            )
+        else:
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, "README.md"))
+                self.save_pretrained(
+                    tmp_dir,
+                    model_name=repo_url.repo_id,
+                    create_model_card=create_model_card,
+                    train_datasets=train_datasets,
+                    safe_serialization=safe_serialization,
+                )
+                folder_url = api.upload_folder(
+                    repo_id=repo_id,
+                    folder_path=tmp_dir,
+                    commit_message=commit_message,
+                    commit_description=commit_description,
+                    revision=revision,
+                    create_pr=create_pr,
+                )
+
+        if create_pr:
+            return folder_url.pr_url
+        return folder_url.commit_url
+
+

Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

+

Args

+
+
repo_id : str
+
Repository name for your model in the Hub, including the user or organization.
+
token : str, optional
+
An authentication token (See https://huggingface.co/settings/token)
+
private : bool, optional
+
Set to true, for hosting a private model
+
safe_serialization : bool, optional
+
If true, save the model using safetensors. If false, save the model the traditional PyTorch way
+
commit_message : str, optional
+
Message to commit while pushing.
+
local_model_path : str, optional
+
Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
+
exist_ok : bool, optional
+
If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
+
replace_model_card : bool, optional
+
If true, replace an existing model card in the hub with the automatically created model card
+
train_datasets : List[str], optional
+
Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
+
revision : str, optional
+
Branch to push the uploaded files to
+
create_pr : bool, optional
+
If True, create a pull request instead of pushing directly to the main branch
+
+

Returns

+
+
str
+
The url of the commit of your model in the repository on the Hugging Face Hub.
+
+
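A hedged usage sketch of push_to_hub; the repository id is illustrative and an authenticated Hugging Face account (via huggingface-cli login or the token argument) is assumed:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

# Returns the commit URL, or the pull request URL when create_pr=True
url = model.push_to_hub(
    "my-user/my-finetuned-model",      # illustrative repo id
    commit_message="Add fine-tuned SentenceTransformer model",
    exist_ok=True,
    create_pr=False,
)
print(url)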
+
+def save(self,
path: str,
model_name: str | None = None,
create_model_card: bool = True,
train_datasets: list[str] | None = None,
safe_serialization: bool = True) ‑> None
+
+
+
+ +Expand source code + +
def save(
+    self,
+    path: str,
+    model_name: str | None = None,
+    create_model_card: bool = True,
+    train_datasets: list[str] | None = None,
+    safe_serialization: bool = True,
+) -> None:
+    """
+    Saves a model and its configuration files to a directory, so that it can be loaded
+    with ``SentenceTransformer(path)`` again.
+
+    Args:
+        path (str): Path on disc where the model will be saved.
+        model_name (str, optional): Optional model name.
+        create_model_card (bool, optional): If True, create a README.md with basic information about this model.
+        train_datasets (List[str], optional): Optional list with the names of the datasets used to train the model.
+        safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model
+            the traditional (but unsafe) PyTorch way.
+    """
+    if path is None:
+        return
+
+    os.makedirs(path, exist_ok=True)
+
+    logger.info(f"Save model to {path}")
+    modules_config = []
+
+    # Save some model info
+    self._model_config["__version__"] = {
+        "sentence_transformers": __version__,
+        "transformers": transformers.__version__,
+        "pytorch": torch.__version__,
+    }
+
+    with open(os.path.join(path, "config_sentence_transformers.json"), "w") as fOut:
+        config = self._model_config.copy()
+        config["prompts"] = self.prompts
+        config["default_prompt_name"] = self.default_prompt_name
+        config["similarity_fn_name"] = self.similarity_fn_name
+        json.dump(config, fOut, indent=2)
+
+    # Save modules
+    for idx, name in enumerate(self._modules):
+        module = self._modules[name]
+        if idx == 0 and hasattr(module, "save_in_root"):  # Save first module in the main folder
+            model_path = path + "/"
+        else:
+            model_path = os.path.join(path, str(idx) + "_" + type(module).__name__)
+
+        os.makedirs(model_path, exist_ok=True)
+        # Try to save with safetensors, but fall back to the traditional PyTorch way if the module doesn't support it
+        try:
+            module.save(model_path, safe_serialization=safe_serialization)
+        except TypeError:
+            module.save(model_path)
+
+        # "module" only works for Sentence Transformers as the modules have the same names as the classes
+        class_ref = type(module).__module__
+        # For remote modules, we want to remove "transformers_modules.{repo_name}":
+        if class_ref.startswith("transformers_modules."):
+            class_file = sys.modules[class_ref].__file__
+
+            # Save the custom module file
+            dest_file = Path(model_path) / (Path(class_file).name)
+            shutil.copy(class_file, dest_file)
+
+            # Save all files imported in the custom module file
+            for needed_file in get_relative_import_files(class_file):
+                dest_file = Path(model_path) / (Path(needed_file).name)
+                shutil.copy(needed_file, dest_file)
+
+            # For remote modules, we want to ignore the "transformers_modules.{repo_id}" part,
+            # i.e. we only want the filename
+            class_ref = f"{class_ref.split('.')[-1]}.{type(module).__name__}"
+        # For other cases, we want to add the class name:
+        elif not class_ref.startswith("sentence_transformers."):
+            class_ref = f"{class_ref}.{type(module).__name__}"
+        modules_config.append({"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref})
+
+    with open(os.path.join(path, "modules.json"), "w") as fOut:
+        json.dump(modules_config, fOut, indent=2)
+
+    # Create model card
+    if create_model_card:
+        self._create_model_card(path, model_name, train_datasets)
+
+

Saves a model and its configuration files to a directory, so that it can be loaded +with SentenceTransformer(path) again.

+

Args

+
+
path : str
+
Path on disc where the model will be saved.
+
model_name : str, optional
+
Optional model name.
+
create_model_card : bool, optional
+
If True, create a README.md with basic information about this model.
+
train_datasets : List[str], optional
+
Optional list with the names of the datasets used to train the model.
+
safe_serialization : bool, optional
+
If True, save the model using safetensors. If False, save the model +the traditional (but unsafe) PyTorch way.
+
+
+
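A short, hedged usage sketch of save; the output directory and model id are placeholder assumptions:

>>> from sentence_transformers import SentenceTransformer
>>> model = SentenceTransformer("all-MiniLM-L6-v2")       # placeholder model
>>> model.save("output/my-model")                         # writes modules.json and the module folders
>>> reloaded = SentenceTransformer("output/my-model")     # load the saved model back from disk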
+def save_pretrained(self,
path: str,
model_name: str | None = None,
create_model_card: bool = True,
train_datasets: list[str] | None = None,
safe_serialization: bool = True) ‑> None
+
+
+
+ +Expand source code + +
def save_pretrained(
+    self,
+    path: str,
+    model_name: str | None = None,
+    create_model_card: bool = True,
+    train_datasets: list[str] | None = None,
+    safe_serialization: bool = True,
+) -> None:
+    """
+    Saves a model and its configuration files to a directory, so that it can be loaded
+    with ``SentenceTransformer(path)`` again.
+
+    Args:
+        path (str): Path on disc where the model will be saved.
+        model_name (str, optional): Optional model name.
+        create_model_card (bool, optional): If True, create a README.md with basic information about this model.
+        train_datasets (List[str], optional): Optional list with the names of the datasets used to train the model.
+        safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model
+            the traditional (but unsafe) PyTorch way.
+    """
+    self.save(
+        path,
+        model_name=model_name,
+        create_model_card=create_model_card,
+        train_datasets=train_datasets,
+        safe_serialization=safe_serialization,
+    )
+
+

Saves a model and its configuration files to a directory, so that it can be loaded +with SentenceTransformer(path) again.

+

Args

+
+
path : str
+
Path on disc where the model will be saved.
+
model_name : str, optional
+
Optional model name.
+
create_model_card : bool, optional
+
If True, create a README.md with basic information about this model.
+
train_datasets : List[str], optional
+
Optional list with the names of the datasets used to train the model.
+
safe_serialization : bool, optional
+
If True, save the model using safetensors. If False, save the model +the traditional (but unsafe) PyTorch way.
+
+
+
+def save_to_hub(self,
repo_id: str,
organization: str | None = None,
token: str | None = None,
private: bool | None = None,
safe_serialization: bool = True,
commit_message: str = 'Add new SentenceTransformer model.',
local_model_path: str | None = None,
exist_ok: bool = False,
replace_model_card: bool = False,
train_datasets: list[str] | None = None) ‑> str
+
+
+
+ +Expand source code + +
@save_to_hub_args_decorator
+def save_to_hub(
+    self,
+    repo_id: str,
+    organization: str | None = None,
+    token: str | None = None,
+    private: bool | None = None,
+    safe_serialization: bool = True,
+    commit_message: str = "Add new SentenceTransformer model.",
+    local_model_path: str | None = None,
+    exist_ok: bool = False,
+    replace_model_card: bool = False,
+    train_datasets: list[str] | None = None,
+) -> str:
+    """
+    DEPRECATED, use `push_to_hub` instead.
+
+    Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.
+
+    Args:
+        repo_id (str): Repository name for your model in the Hub, including the user or organization.
+        token (str, optional): An authentication token (See https://huggingface.co/settings/token)
+        private (bool, optional): Set to true, for hosting a private model
+        safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way
+        commit_message (str, optional): Message to commit while pushing.
+        local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
+        exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
+        replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card
+        train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
+
+    Returns:
+        str: The url of the commit of your model in the repository on the Hugging Face Hub.
+    """
+    logger.warning(
+        "The `save_to_hub` method is deprecated and will be removed in a future version of SentenceTransformers."
+        " Please use `push_to_hub` instead for future model uploads."
+    )
+
+    if organization:
+        if "/" not in repo_id:
+            logger.warning(
+                f'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="{organization}/{repo_id}"` instead.'
+            )
+            repo_id = f"{organization}/{repo_id}"
+        elif repo_id.split("/")[0] != organization:
+            raise ValueError(
+                "Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."
+            )
+        else:
+            logger.warning(
+                f'Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id="{repo_id}"` instead.'
+            )
+
+    return self.push_to_hub(
+        repo_id=repo_id,
+        token=token,
+        private=private,
+        safe_serialization=safe_serialization,
+        commit_message=commit_message,
+        local_model_path=local_model_path,
+        exist_ok=exist_ok,
+        replace_model_card=replace_model_card,
+        train_datasets=train_datasets,
+    )
+
+

DEPRECATED, use push_to_hub instead.

+

Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

+

Args

+
+
repo_id : str
+
Repository name for your model in the Hub, including the user or organization.
+
token : str, optional
+
An authentication token (See https://huggingface.co/settings/token)
+
private : bool, optional
+
Set to true, for hosting a private model
+
safe_serialization : bool, optional
+
If true, save the model using safetensors. If false, save the model the traditional PyTorch way
+
commit_message : str, optional
+
Message to commit while pushing.
+
local_model_path : str, optional
+
Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
+
exist_ok : bool, optional
+
If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
+
replace_model_card : bool, optional
+
If true, replace an existing model card in the hub with the automatically created model card
+
train_datasets : List[str], optional
+
Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
+
+

Returns

+
+
str
+
The url of the commit of your model in the repository on the Hugging Face Hub.
+
+
+
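Because save_to_hub is deprecated, a hedged migration sketch (the repository id is a placeholder):

>>> # deprecated call:
>>> # model.save_to_hub("my-model", organization="my-org")
>>> # preferred equivalent:
>>> model.push_to_hub("my-org/my-model")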
+def set_pooling_include_prompt(self, include_prompt: bool) ‑> None +
+
+
+ +Expand source code + +
def set_pooling_include_prompt(self, include_prompt: bool) -> None:
+    """
+    Sets the `include_prompt` attribute in the pooling layer in the model, if there is one.
+
+    This is useful for INSTRUCTOR models, as the prompt should be excluded from the pooling strategy
+    for these models.
+
+    Args:
+        include_prompt (bool): Whether to include the prompt in the pooling layer.
+
+    Returns:
+        None
+    """
+    for module in self:
+        if isinstance(module, Pooling):
+            module.include_prompt = include_prompt
+            break
+
+

Sets the include_prompt attribute in the pooling layer in the model, if there is one.

+

This is useful for INSTRUCTOR models, as the prompt should be excluded from the pooling strategy +for these models.

+

Args

+
+
include_prompt : bool
+
Whether to include the prompt in the pooling layer.
+
+

Returns

+

None

+
+
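A brief sketch of the call; the INSTRUCTOR-style model id is a placeholder assumption:

>>> from sentence_transformers import SentenceTransformer
>>> model = SentenceTransformer("hkunlp/instructor-base")    # placeholder INSTRUCTOR-style model
>>> model.set_pooling_include_prompt(False)                  # exclude the prompt from pooling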
+def start_multi_process_pool(self, target_devices: list[str] = None) ‑> dict[typing.Literal['input', 'output', 'processes'], typing.Any] +
+
+
+ +Expand source code + +
def start_multi_process_pool(
+    self, target_devices: list[str] = None
+) -> dict[Literal["input", "output", "processes"], Any]:
+    """
+    Starts a multi-process pool to process the encoding with several independent processes
+    via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.
+
+    This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
+    to start only one process per GPU. This method works together with encode_multi_process
+    and stop_multi_process_pool.
+
+    Args:
+        target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
+            ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
+            is available, then all available CUDA/NPU devices will be used. If target_devices is None and
+            CUDA/NPU is not available, then 4 CPU devices will be used.
+
+    Returns:
+        Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
+    """
+    if target_devices is None:
+        if torch.cuda.is_available():
+            target_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+        elif is_torch_npu_available():
+            target_devices = [f"npu:{i}" for i in range(torch.npu.device_count())]
+        else:
+            logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
+            target_devices = ["cpu"] * 4
+
+    logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))
+
+    self.to("cpu")
+    self.share_memory()
+    ctx = mp.get_context("spawn")
+    input_queue = ctx.Queue()
+    output_queue = ctx.Queue()
+    processes = []
+
+    for device_id in target_devices:
+        p = ctx.Process(
+            target=SentenceTransformer._encode_multi_process_worker,
+            args=(device_id, self, input_queue, output_queue),
+            daemon=True,
+        )
+        p.start()
+        processes.append(p)
+
+    return {"input": input_queue, "output": output_queue, "processes": processes}
+
+

Starts a multi-process pool to process the encoding with several independent processes +via :meth:SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>.

+

This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised +to start only one process per GPU. This method works together with encode_multi_process +and stop_multi_process_pool.

+

Args

+
+
target_devices : List[str], optional
+
PyTorch target devices, e.g. ["cuda:0", "cuda:1", …], +["npu:0", "npu:1", …], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU +is available, then all available CUDA/NPU devices will be used. If target_devices is None and +CUDA/NPU is not available, then 4 CPU devices will be used.
+
+

Returns

+
+
Dict[str, Any]
+
A dictionary with the target processes, an input queue, and an output queue.
+
+
+
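A hedged end-to-end sketch combining this method with encode_multi_process and stop_multi_process_pool; the sentence list and worker count are assumptions:

>>> model = SentenceTransformer("all-MiniLM-L6-v2")
>>> pool = model.start_multi_process_pool(["cpu", "cpu"])    # two CPU workers
>>> embeddings = model.encode_multi_process(["first text", "second text"], pool)
>>> model.stop_multi_process_pool(pool)
>>> embeddings.shape                                         # (2, embedding_dimension)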
+def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) ‑> dict[str, torch.Tensor] +
+
+
+ +Expand source code + +
def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) -> dict[str, Tensor]:
+    """
+    Tokenizes the texts.
+
+    Args:
+        texts (Union[List[str], List[Dict], List[Tuple[str, str]]]): A list of texts to be tokenized.
+
+    Returns:
+        Dict[str, Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids",
+            "attention_mask", and "token_type_ids".
+    """
+    return self._first_module().tokenize(texts)
+
+

Tokenizes the texts.

+

Args

+
+
texts : Union[List[str], List[Dict], List[Tuple[str, str]]]
+
A list of texts to be tokenized.
+
+

Returns

+
+
Dict[str, Tensor]
+
A dictionary of tensors with the tokenized texts. Common keys are "input_ids", +"attention_mask", and "token_type_ids".
+
+
+
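A small sketch of tokenize; the input strings are arbitrary examples and the exact keys depend on the underlying tokenizer:

>>> features = model.tokenize(["Hello world", "How are you?"])
>>> features["input_ids"].shape         # (batch_size, sequence_length)
>>> features["attention_mask"].shape    # same shape as input_ids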
+def truncate_sentence_embeddings(self, truncate_dim: int | None) ‑> Iterator[None] +
+
+
+ +Expand source code + +
@contextmanager
+def truncate_sentence_embeddings(self, truncate_dim: int | None) -> Iterator[None]:
+    """
+    In this context, :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>` outputs
+    sentence embeddings truncated at dimension ``truncate_dim``.
+
+    This may be useful when you are using the same model for different applications where different dimensions
+    are needed.
+
+    Args:
+        truncate_dim (int, optional): The dimension to truncate sentence embeddings to. ``None`` does no truncation.
+
+    Example:
+        ::
+
+            from sentence_transformers import SentenceTransformer
+
+            model = SentenceTransformer("all-mpnet-base-v2")
+
+            with model.truncate_sentence_embeddings(truncate_dim=16):
+                embeddings_truncated = model.encode(["hello there", "hiya"])
+            assert embeddings_truncated.shape[-1] == 16
+    """
+    original_output_dim = self.truncate_dim
+    try:
+        self.truncate_dim = truncate_dim
+        yield
+    finally:
+        self.truncate_dim = original_output_dim
+
+

In this context, :meth:SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode> outputs +sentence embeddings truncated at dimension truncate_dim.

+

This may be useful when you are using the same model for different applications where different dimensions +are needed.

+

Args

+
+
truncate_dim : int, optional
+
The dimension to truncate sentence embeddings to. None does no truncation.
+
+

Example

+

::

+
from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("all-mpnet-base-v2")
+
+with model.truncate_sentence_embeddings(truncate_dim=16):
+    embeddings_truncated = model.encode(["hello there", "hiya"])
+assert embeddings_truncated.shape[-1] == 16
+
+
+
+
+
+class Tensor +(...) +
+
+
+

Ancestors

+
    +
  • torch._C.TensorBase
  • +
+

Subclasses

+
    +
  • torch._subclasses.fake_tensor.FakeTensor
  • +
  • torch._subclasses.functional_tensor.FunctionalTensor
  • +
  • torch.masked.maskedtensor.core.MaskedTensor
  • +
  • torch.nn.parameter.Buffer
  • +
  • torch.nn.parameter.Parameter
  • +
  • torch.nn.parameter.UninitializedBuffer
  • +
  • torch.sparse.semi_structured.SparseSemiStructuredTensor
  • +
  • torch.testing._internal.logging_tensor.LoggingTensor
  • +
+

Methods

+
+
+def align_to(self, *names) +
+
+
+ +Expand source code + +
def align_to(self, *names):
+    r"""Permutes the dimensions of the :attr:`self` tensor to match the order
+    specified in :attr:`names`, adding size-one dims for any new names.
+
+    All of the dims of :attr:`self` must be named in order to use this method.
+    The resulting tensor is a view on the original tensor.
+
+    All dimension names of :attr:`self` must be present in :attr:`names`.
+    :attr:`names` may contain additional names that are not in ``self.names``;
+    the output tensor has a size-one dimension for each of those new names.
+
+    :attr:`names` may contain up to one Ellipsis (``...``).
+    The Ellipsis is expanded to be equal to all dimension names of :attr:`self`
+    that are not mentioned in :attr:`names`, in the order that they appear
+    in :attr:`self`.
+
+    Python 2 does not support Ellipsis but one may use a string literal
+    instead (``'...'``).
+
+    Args:
+        names (iterable of str): The desired dimension ordering of the
+            output tensor. May contain up to one Ellipsis that is expanded
+            to all unmentioned dim names of :attr:`self`.
+
+    Examples::
+
+        >>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
+        >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
+
+        # Move the F and E dims to the front while keeping the rest in order
+        >>> named_tensor.align_to('F', 'E', ...)
+
+    .. warning::
+        The named tensor API is experimental and subject to change.
+
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.align_to, (self,), self, *names)
+    ellipsis_idx = single_ellipsis_index(names, "align_to")
+    if ellipsis_idx is None:
+        return super().align_to(names)
+    return super().align_to(
+        [name for name in names if not is_ellipsis(name)], ellipsis_idx
+    )
+
+

Permutes the dimensions of the :attr:self tensor to match the order +specified in :attr:names, adding size-one dims for any new names.

+

All of the dims of :attr:self must be named in order to use this method. +The resulting tensor is a view on the original tensor.

+

All dimension names of :attr:self must be present in :attr:names. +:attr:names may contain additional names that are not in self.names; +the output tensor has a size-one dimension for each of those new names.

+

:attr:names may contain up to one Ellipsis (...). +The Ellipsis is expanded to be equal to all dimension names of :attr:self +that are not mentioned in :attr:names, in the order that they appear +in :attr:self.

+

Python 2 does not support Ellipsis but one may use a string literal +instead ('...').

+

Args

+
+
names : iterable of str
+
The desired dimension ordering of the +output tensor. May contain up to one Ellipsis that is expanded +to all unmentioned dim names of :attr:self.
+
+

Examples::

+
>>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
+>>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
+
+# Move the F and E dims to the front while keeping the rest in order
+>>> named_tensor.align_to('F', 'E', ...)
+
+
+

Warning

+

The named tensor API is experimental and subject to change.

+
+
+
+def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None) +
+
+
+ +Expand source code + +
def backward(
+    self, gradient=None, retain_graph=None, create_graph=False, inputs=None
+):
+    r"""Computes the gradient of current tensor wrt graph leaves.
+
+    The graph is differentiated using the chain rule. If the tensor is
+    non-scalar (i.e. its data has more than one element) and requires
+    gradient, the function additionally requires specifying a ``gradient``.
+    It should be a tensor of matching type and shape, that represents
+    the gradient of the differentiated function w.r.t. ``self``.
+
+    This function accumulates gradients in the leaves - you might need to zero
+    ``.grad`` attributes or set them to ``None`` before calling it.
+    See :ref:`Default gradient layouts<default-grad-layouts>`
+    for details on the memory layout of accumulated gradients.
+
+    .. note::
+
+        If you run any forward ops, create ``gradient``, and/or call ``backward``
+        in a user-specified CUDA stream context, see
+        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
+
+    .. note::
+
+        When ``inputs`` are provided and a given input is not a leaf,
+        the current implementation will call its grad_fn (though it is not strictly needed to get these gradients).
+        It is an implementation detail on which the user should not rely.
+        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
+
+    Args:
+        gradient (Tensor, optional): The gradient of the function
+            being differentiated w.r.t. ``self``.
+            This argument can be omitted if ``self`` is a scalar.
+        retain_graph (bool, optional): If ``False``, the graph used to compute
+            the grads will be freed. Note that in nearly all cases setting
+            this option to True is not needed and often can be worked around
+            in a much more efficient way. Defaults to the value of
+            ``create_graph``.
+        create_graph (bool, optional): If ``True``, graph of the derivative will
+            be constructed, allowing to compute higher order derivative
+            products. Defaults to ``False``.
+        inputs (sequence of Tensor, optional): Inputs w.r.t. which the gradient will be
+            accumulated into ``.grad``. All other tensors will be ignored. If not
+            provided, the gradient is accumulated into all the leaf Tensors that were
+            used to compute the :attr:`tensors`.
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.backward,
+            (self,),
+            self,
+            gradient=gradient,
+            retain_graph=retain_graph,
+            create_graph=create_graph,
+            inputs=inputs,
+        )
+    torch.autograd.backward(
+        self, gradient, retain_graph, create_graph, inputs=inputs
+    )
+
+

Computes the gradient of current tensor wrt graph leaves.

+

The graph is differentiated using the chain rule. If the tensor is +non-scalar (i.e. its data has more than one element) and requires +gradient, the function additionally requires specifying a gradient. +It should be a tensor of matching type and shape, that represents +the gradient of the differentiated function w.r.t. self.

+

This function accumulates gradients in the leaves - you might need to zero +.grad attributes or set them to None before calling it. +See :ref:Default gradient layouts<default-grad-layouts> +for details on the memory layout of accumulated gradients.

+
+

Note

+

If you run any forward ops, create gradient, and/or call backward +in a user-specified CUDA stream context, see +:ref:Stream semantics of backward passes<bwd-cuda-stream-semantics>.

+
+
+

Note

+

When inputs are provided and a given input is not a leaf, +the current implementation will call its grad_fn (though it is not strictly needed to get these gradients). +It is an implementation detail on which the user should not rely. +See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

+
+

Args

+
+
gradient : Tensor, optional
+
The gradient of the function +being differentiated w.r.t. self. +This argument can be omitted if self is a scalar.
+
retain_graph : bool, optional
+
If False, the graph used to compute +the grads will be freed. Note that in nearly all cases setting +this option to True is not needed and often can be worked around +in a much more efficient way. Defaults to the value of +create_graph.
+
create_graph : bool, optional
+
If True, graph of the derivative will +be constructed, allowing to compute higher order derivative +products. Defaults to False.
+
inputs : sequence of Tensor, optional
+
Inputs w.r.t. which the gradient will be +accumulated into .grad. All other tensors will be ignored. If not +provided, the gradient is accumulated into all the leaf Tensors that were +used to compute the :attr:tensors.
+
+
+
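A hedged sketch showing the gradient argument that non-scalar tensors require:

>>> x = torch.tensor([1., 2., 3.], requires_grad=True)
>>> y = x * 2                         # non-scalar result
>>> y.backward(torch.ones_like(y))    # gradient w.r.t. y must be supplied
>>> x.grad
tensor([2., 2., 2.])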
+def detach(...) +
+
+

Returns a new Tensor, detached from the current graph.

+

The result will never require gradient.

+

This method also affects forward mode AD gradients and the result will never +have forward mode AD gradients.

+
+

Note

+

Returned Tensor shares the same storage with the original one. +In-place modifications on either of them will be seen, and may trigger +errors in correctness checks.

+
+
+
+def detach_(...) +
+
+

Detaches the Tensor from the graph that created it, making it a leaf. +Views cannot be detached in-place.

+

This method also affects forward mode AD gradients and the result will never +have forward mode AD gradients.

+
+
+def dim_order(self) +
+
+
+ +Expand source code + +
def dim_order(self):
+    """
+
+    dim_order() -> tuple
+
+    Returns a tuple of int describing the dim order or physical layout of :attr:`self`.
+
+    Args:
+        None
+
+    Dim order represents how dimensions are laid out in memory,
+    starting from the outermost to the innermost dimension.
+
+    Example::
+        >>> torch.empty((2, 3, 5, 7)).dim_order()
+        (0, 1, 2, 3)
+        >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
+        (0, 2, 3, 1)
+
+    .. warning::
+        The dim_order tensor API is experimental and subject to change.
+
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.dim_order, (self,), self)
+
+    import torch._prims_common as utils
+
+    return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
+
+

dim_order() -> tuple

+

Returns a tuple of int describing the dim order or physical layout of :attr:self.

+

Args

+

None +Dim order represents how dimensions are laid out in memory, +starting from the outermost to the innermost dimension.

+

Example:: +>>> torch.empty((2, 3, 5, 7)).dim_order() +(0, 1, 2, 3) +>>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order() +(0, 2, 3, 1)

+
+

Warning

+

The dim_order tensor API is experimental and subject to change.

+
+
+
+def eig(self, eigenvectors=False) +
+
+
+ +Expand source code + +
def eig(self, eigenvectors=False):
+    from torch._linalg_utils import eig
+
+    return eig(self, eigenvectors=eigenvectors)
+
+
+
+
+def is_shared(self) +
+
+
+ +Expand source code + +
def is_shared(self):
+    r"""Checks if tensor is in shared memory.
+
+    This is always ``True`` for CUDA tensors.
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.is_shared, (self,), self)
+    return self._typed_storage()._is_shared()
+
+

Checks if tensor is in shared memory.

+

This is always True for CUDA tensors.

+
+
+def istft(self,
n_fft: int,
hop_length: int | None = None,
win_length: int | None = None,
window: Optional[Tensor] = None,
center: bool = True,
normalized: bool = False,
onesided: bool | None = None,
length: int | None = None,
return_complex: bool = False)
+
+
+
+ +Expand source code + +
def istft(
+    self,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: "Optional[Tensor]" = None,
+    center: bool = True,
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    length: Optional[int] = None,
+    return_complex: bool = False,
+):
+    r"""See :func:`torch.istft`"""
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.istft,
+            (self,),
+            self,
+            n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            window=window,
+            center=center,
+            normalized=normalized,
+            onesided=onesided,
+            length=length,
+            return_complex=return_complex,
+        )
+    return torch.istft(
+        self,
+        n_fft,
+        hop_length,
+        win_length,
+        window,
+        center,
+        normalized,
+        onesided,
+        length,
+        return_complex=return_complex,
+    )
+
+

See :func:torch.istft

+
+
+def lstsq(self, other) +
+
+
+ +Expand source code + +
def lstsq(self, other):
+    from torch._linalg_utils import lstsq
+
+    return lstsq(self, other)
+
+
+
+
+def lu(self, pivot=True, get_infos=False) +
+
+
+ +Expand source code + +
def lu(self, pivot=True, get_infos=False):
+    r"""See :func:`torch.lu`"""
+    # If get_infos is True, then we don't need to check for errors and vice versa
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos
+        )
+
+    LU, pivots, infos = torch._lu_with_info(
+        self, pivot=pivot, check_errors=(not get_infos)
+    )
+    if get_infos:
+        return LU, pivots, infos
+    else:
+        return LU, pivots
+
+

See :func:torch.lu

+
+
+def module_load(self, other, assign=False) +
+
+
+ +Expand source code + +
def module_load(self, other, assign=False):
+    r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.
+
+    Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
+
+    It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the
+    value in the state dictionary with the corresponding key, this method defines
+    how ``other`` is remapped before being swapped with ``self`` via
+    :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`.
+
+    .. note::
+        This method should always return a new object that is not ``self`` or ``other``.
+        For example, the default implementation returns ``self.copy_(other).detach()``
+        if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``.
+
+    Args:
+        other (Tensor): value in state dict with key corresponding to ``self``
+        assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
+
+    """
+    if has_torch_function_variadic(self, other):
+        return handle_torch_function(
+            Tensor.module_load, (self, other), self, other, assign=assign
+        )
+
+    if assign:
+        return other.detach()
+    else:
+        return self.copy_(other).detach()
+
+

Defines how to transform other when loading it into self in :meth:~nn.Module.load_state_dict.

+

Used when :func:~torch.__future__.get_swap_module_params_on_conversion is True.

+

It is expected that self is a parameter or buffer in an nn.Module and other is the +value in the state dictionary with the corresponding key, this method defines +how other is remapped before being swapped with self via +:func:~torch.utils.swap_tensors in :meth:~nn.Module.load_state_dict.

+
+

Note

+

This method should always return a new object that is not self or other. +For example, the default implementation returns self.copy_(other).detach() +if assign is False or other.detach() if assign is True.

+
+

Args

+
+
other : Tensor
+
value in state dict with key corresponding to self
+
assign : bool
+
the assign argument passed to :meth:nn.Module.load_state_dict
+
+
+
+def norm(self, p: float | str | None = 'fro', dim=None, keepdim=False, dtype=None) +
+
+
+ +Expand source code + +
def norm(
+    self,
+    p: Optional[Union[float, str]] = "fro",
+    dim=None,
+    keepdim=False,
+    dtype=None,
+):
+    r"""See :func:`torch.norm`"""
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype
+        )
+    return torch.norm(self, p, dim, keepdim, dtype=dtype)
+
+

See :func:torch.norm

+
+
+def refine_names(self, *names) +
+
+
+ +Expand source code + +
def refine_names(self, *names):
+    r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
+
+    Refining is a special case of renaming that "lifts" unnamed dimensions.
+    A ``None`` dim can be refined to have any name; a named dim can only be
+    refined to have the same name.
+
+    Because named tensors can coexist with unnamed tensors, refining names
+    gives a nice way to write named-tensor-aware code that works with both
+    named and unnamed tensors.
+
+    :attr:`names` may contain up to one Ellipsis (``...``).
+    The Ellipsis is expanded greedily; it is expanded in-place to fill
+    :attr:`names` to the same length as ``self.dim()`` using names from the
+    corresponding indices of ``self.names``.
+
+    Python 2 does not support Ellipsis but one may use a string literal
+    instead (``'...'``).
+
+    Args:
+        names (iterable of str): The desired names of the output tensor. May
+            contain up to one Ellipsis.
+
+    Examples::
+
+        >>> imgs = torch.randn(32, 3, 128, 128)
+        >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
+        >>> named_imgs.names
+        ('N', 'C', 'H', 'W')
+
+        >>> tensor = torch.randn(2, 3, 5, 7, 11)
+        >>> tensor = tensor.refine_names('A', ..., 'B', 'C')
+        >>> tensor.names
+        ('A', None, None, 'B', 'C')
+
+    .. warning::
+        The named tensor API is experimental and subject to change.
+
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.refine_names, (self,), self, *names)
+    names = resolve_ellipsis(names, self.names, "refine_names")
+    return super().refine_names(names)
+
+

Refines the dimension names of :attr:self according to :attr:names.

+

Refining is a special case of renaming that "lifts" unnamed dimensions. +A None dim can be refined to have any name; a named dim can only be +refined to have the same name.

+

Because named tensors can coexist with unnamed tensors, refining names +gives a nice way to write named-tensor-aware code that works with both +named and unnamed tensors.

+

:attr:names may contain up to one Ellipsis (...). +The Ellipsis is expanded greedily; it is expanded in-place to fill +:attr:names to the same length as self.dim() using names from the +corresponding indices of self.names.

+

Python 2 does not support Ellipsis but one may use a string literal +instead ('...').

+

Args

+
+
names : iterable of str
+
The desired names of the output tensor. May +contain up to one Ellipsis.
+
+

Examples::

+
>>> imgs = torch.randn(32, 3, 128, 128)
+>>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
+>>> named_imgs.names
+('N', 'C', 'H', 'W')
+
+>>> tensor = torch.randn(2, 3, 5, 7, 11)
+>>> tensor = tensor.refine_names('A', ..., 'B', 'C')
+>>> tensor.names
+('A', None, None, 'B', 'C')
+
+
+

Warning

+

The named tensor API is experimental and subject to change.

+
+
+
+def register_hook(self, hook) +
+
+
+ +Expand source code + +
def register_hook(self, hook):
+    r"""Registers a backward hook.
+
+    The hook will be called every time a gradient with respect to the
+    Tensor is computed. The hook should have the following signature::
+
+        hook(grad) -> Tensor or None
+
+
+    The hook should not modify its argument, but it can optionally return
+    a new gradient which will be used in place of :attr:`grad`.
+
+    This function returns a handle with a method ``handle.remove()``
+    that removes the hook from the module.
+
+    .. note::
+        See :ref:`backward-hooks-execution` for more information on when this hook
+        is executed, and how its execution is ordered relative to other hooks.
+
+    Example::
+
+        >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
+        >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
+        >>> v.backward(torch.tensor([1., 2., 3.]))
+        >>> v.grad
+
+         2
+         4
+         6
+        [torch.FloatTensor of size (3,)]
+
+        >>> h.remove()  # removes the hook
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.register_hook, (self,), self, hook)
+    if not self.requires_grad:
+        raise RuntimeError(
+            "cannot register a hook on a tensor that doesn't require gradient"
+        )
+    if self._backward_hooks is None:
+        self._backward_hooks = OrderedDict()
+        if self.grad_fn is not None:
+            self.grad_fn._register_hook_dict(self)
+
+    from torch.utils.hooks import RemovableHandle
+
+    handle = RemovableHandle(self._backward_hooks)
+    self._backward_hooks[handle.id] = hook
+    return handle
+
+

Registers a backward hook.

+

The hook will be called every time a gradient with respect to the +Tensor is computed. The hook should have the following signature::

+
hook(grad) -> Tensor or None
+
+

The hook should not modify its argument, but it can optionally return +a new gradient which will be used in place of :attr:grad.

+

This function returns a handle with a method handle.remove() +that removes the hook from the module.

+
+

Note

+

See :ref:backward-hooks-execution for more information on when this hook +is executed, and how its execution is ordered relative to other hooks.

+
+

Example::

+
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
+>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
+>>> v.backward(torch.tensor([1., 2., 3.]))
+>>> v.grad
+
+ 2
+ 4
+ 6
+[torch.FloatTensor of size (3,)]
+
+>>> h.remove()  # removes the hook
+
+
+
+def register_post_accumulate_grad_hook(self, hook) +
+
+
+ +Expand source code + +
def register_post_accumulate_grad_hook(self, hook):
+    r"""Registers a backward hook that runs after grad accumulation.
+
+    The hook will be called after all gradients for a tensor have been accumulated,
+    meaning that the .grad field has been updated on that tensor. The post
+    accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
+    .grad_fn field). Registering this hook on a non-leaf tensor will error!
+
+    The hook should have the following signature::
+
+        hook(param: Tensor) -> None
+
+    Note that, unlike other autograd hooks, this hook operates on the tensor
+    that requires grad and not the grad itself. The hook can in-place modify
+    and access its Tensor argument, including its .grad field.
+
+    This function returns a handle with a method ``handle.remove()``
+    that removes the hook from the module.
+
+    .. note::
+        See :ref:`backward-hooks-execution` for more information on when this hook
+        is executed, and how its execution is ordered relative to other hooks. Since
+        this hook runs during the backward pass, it will run in no_grad mode (unless
+        create_graph is True). You can use torch.enable_grad() to re-enable autograd
+        within the hook if you need it.
+
+    Example::
+
+        >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
+        >>> lr = 0.01
+        >>> # simulate a simple SGD update
+        >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
+        >>> v.backward(torch.tensor([1., 2., 3.]))
+        >>> v
+        tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
+
+        >>> h.remove()  # removes the hook
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.register_post_accumulate_grad_hook, (self,), self, hook
+        )
+    if not self.requires_grad:
+        raise RuntimeError(
+            "cannot register a hook on a tensor that doesn't require gradient"
+        )
+    if self.grad_fn is not None:
+        raise RuntimeError(
+            "post accumulate grad hooks cannot be registered on non-leaf tensors"
+        )
+    if self._post_accumulate_grad_hooks is None:
+        self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
+
+    from torch.utils.hooks import RemovableHandle
+
+    handle = RemovableHandle(self._post_accumulate_grad_hooks)
+    self._post_accumulate_grad_hooks[handle.id] = hook
+    return handle
+
+

Registers a backward hook that runs after grad accumulation.

+

The hook will be called after all gradients for a tensor have been accumulated, +meaning that the .grad field has been updated on that tensor. The post +accumulate grad hook is ONLY applicable for leaf tensors (tensors without a +.grad_fn field). Registering this hook on a non-leaf tensor will error!

+

The hook should have the following signature::

+
hook(param: Tensor) -> None
+
+

Note that, unlike other autograd hooks, this hook operates on the tensor +that requires grad and not the grad itself. The hook can in-place modify +and access its Tensor argument, including its .grad field.

+

This function returns a handle with a method handle.remove() +that removes the hook from the module.

+
+

Note

+

See :ref:backward-hooks-execution for more information on when this hook +is executed, and how its execution is ordered relative to other hooks. Since +this hook runs during the backward pass, it will run in no_grad mode (unless +create_graph is True). You can use torch.enable_grad() to re-enable autograd +within the hook if you need it.

+
+

Example::

+
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
+>>> lr = 0.01
+>>> # simulate a simple SGD update
+>>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
+>>> v.backward(torch.tensor([1., 2., 3.]))
+>>> v
+tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
+
+>>> h.remove()  # removes the hook
+
+
+
+def reinforce(self, reward) +
+
+
+ +Expand source code + +
def reinforce(self, reward):
+    def trim(str):
+        return "\n".join([line.strip() for line in str.split("\n")])
+
+    raise RuntimeError(
+        trim(
+            r"""reinforce() was removed.
+        Use torch.distributions instead.
+        See https://pytorch.org/docs/main/distributions.html
+
+        Instead of:
+
+        probs = policy_network(state)
+        action = probs.multinomial()
+        next_state, reward = env.step(action)
+        action.reinforce(reward)
+        action.backward()
+
+        Use:
+
+        probs = policy_network(state)
+        # NOTE: categorical is equivalent to what used to be called multinomial
+        m = torch.distributions.Categorical(probs)
+        action = m.sample()
+        next_state, reward = env.step(action)
+        loss = -m.log_prob(action) * reward
+        loss.backward()
+    """
+        )
+    )
+
+
+
+
+def rename(self, *names, **rename_map) +
+
+
+ +Expand source code + +
def rename(self, *names, **rename_map):
+    """Renames dimension names of :attr:`self`.
+
+    There are two main usages:
+
+    ``self.rename(**rename_map)`` returns a view on tensor that has dims
+    renamed as specified in the mapping :attr:`rename_map`.
+
+    ``self.rename(*names)`` returns a view on tensor, renaming all
+    dimensions positionally using :attr:`names`.
+    Use ``self.rename(None)`` to drop names on a tensor.
+
+    One cannot specify both positional args :attr:`names` and keyword args
+    :attr:`rename_map`.
+
+    Examples::
+
+        >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
+        >>> renamed_imgs = imgs.rename(N='batch', C='channels')
+        >>> renamed_imgs.names
+        ('batch', 'channels', 'H', 'W')
+
+        >>> renamed_imgs = imgs.rename(None)
+        >>> renamed_imgs.names
+        (None, None, None, None)
+
+        >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width')
+        >>> renamed_imgs.names
+        ('batch', 'channel', 'height', 'width')
+
+    .. warning::
+        The named tensor API is experimental and subject to change.
+
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.rename, (self,), self, *names, **rename_map
+        )
+
+    # See Note [rename_ / rename API]
+    return update_names(self, names, rename_map, inplace=False)
+
+

Renames dimension names of :attr:self.

+

There are two main usages:

+

self.rename(**rename_map) returns a view on tensor that has dims +renamed as specified in the mapping :attr:rename_map.

+

self.rename(*names) returns a view on tensor, renaming all +dimensions positionally using :attr:names. +Use self.rename(None) to drop names on a tensor.

+

One cannot specify both positional args :attr:names and keyword args +:attr:rename_map.

+

Examples::

+
>>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
+>>> renamed_imgs = imgs.rename(N='batch', C='channels')
+>>> renamed_imgs.names
+('batch', 'channels', 'H', 'W')
+
+>>> renamed_imgs = imgs.rename(None)
+>>> renamed_imgs.names
+(None, None, None, None)
+
+>>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width')
+>>> renamed_imgs.names
+('batch', 'channel', 'height', 'width')
+
+
+

Warning

+

The named tensor API is experimental and subject to change.

+
+
+
+def rename_(self, *names, **rename_map) +
+
+
+ +Expand source code + +
def rename_(self, *names, **rename_map):
+    """In-place version of :meth:`~Tensor.rename`."""
+
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.rename_, (self,), self, *names, **rename_map
+        )
+
+    # Note [rename_ / rename API]
+    # The Python API for these is different from the C++ API. In Python:
+    # 1) tensor.rename(*names) takes a vararglist of names
+    # 2) tensor.rename(**rename_map) takes a map of names to rename.
+    # C++ is static, making it difficult to implement similar behavior.
+    return update_names(self, names, rename_map, inplace=True)
+
+

In-place version of :meth:~Tensor.rename.

+
+
+def resize(self, *sizes) +
+
+
+ +Expand source code + +
def resize(self, *sizes):
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.resize, (self,), self, *sizes)
+    warnings.warn("non-inplace resize is deprecated")
+    from torch.autograd._functions import Resize
+
+    return Resize.apply(self, sizes)
+
+
+
+
+def resize_as(self, tensor) +
+
+
+ +Expand source code + +
def resize_as(self, tensor):
+    if has_torch_function_variadic(self, tensor):
+        return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
+    warnings.warn("non-inplace resize_as is deprecated")
+    from torch.autograd._functions import Resize
+
+    return Resize.apply(self, tensor.size())
+
+
+
+
+def share_memory_(self) +
+
+
+ +Expand source code + +
def share_memory_(self):
+    r"""Moves the underlying storage to shared memory.
+
+    This is a no-op if the underlying storage is already in shared memory
+    and for CUDA tensors. Tensors in shared memory cannot be resized.
+
+    See :meth:`torch.UntypedStorage.share_memory_` for more details.
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.share_memory_, (self,), self)
+    self._typed_storage()._share_memory_()
+    return self
+
+

Moves the underlying storage to shared memory.

+

This is a no-op if the underlying storage is already in shared memory +and for CUDA tensors. Tensors in shared memory cannot be resized.

+

See :meth:torch.UntypedStorage.share_memory_ for more details.

+
+
+def solve(self, other) +
+
+
+ +Expand source code + +
def solve(self, other):
+    from torch._linalg_utils import solve
+
+    return solve(self, other)
+
+
+
+
+def split(self, split_size, dim=0) +
+
+
+ +Expand source code + +
def split(self, split_size, dim=0):
+    r"""See :func:`torch.split`"""
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.split, (self,), self, split_size, dim=dim
+        )
+    if isinstance(split_size, Tensor):
+        try:
+            split_size = int(split_size)
+        except ValueError:
+            pass
+
+    if isinstance(split_size, (int, torch.SymInt)):
+        return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
+    else:
+        return torch._VF.split_with_sizes(self, split_size, dim)
+
+

See :func:torch.split

+
+
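A short sketch of both split modes (fixed chunk size vs. explicit sizes):

>>> t = torch.arange(10)
>>> t.split(4)            # equal chunks of size 4; the last chunk may be smaller
(tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9]))
>>> t.split([3, 7])       # explicit chunk sizes along dim 0
(tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7, 8, 9]))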
+def stft(self,
n_fft: int,
hop_length: int | None = None,
win_length: int | None = None,
window: Optional[Tensor] = None,
center: bool = True,
pad_mode: str = 'reflect',
normalized: bool = False,
onesided: bool | None = None,
return_complex: bool | None = None)
+
+
+
+ +Expand source code + +
def stft(
+    self,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: "Optional[Tensor]" = None,
+    center: bool = True,
+    pad_mode: str = "reflect",
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    return_complex: Optional[bool] = None,
+):
+    r"""See :func:`torch.stft`
+
+    .. warning::
+      This function changed signature at version 0.4.1. Calling with
+      the previous signature may cause error or return incorrect result.
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.stft,
+            (self,),
+            self,
+            n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            normalized=normalized,
+            onesided=onesided,
+            return_complex=return_complex,
+        )
+    return torch.stft(
+        self,
+        n_fft,
+        hop_length,
+        win_length,
+        window,
+        center,
+        pad_mode,
+        normalized,
+        onesided,
+        return_complex=return_complex,
+    )
+
+

See :func:torch.stft

+
+

Warning

+

This function changed signature at version 0.4.1. Calling with +the previous signature may cause error or return incorrect result.

+
+
+
+def storage(self) +
+
+
+ +Expand source code + +
def storage(self):
+    r"""
+    storage() -> torch.TypedStorage
+
+    Returns the underlying :class:`TypedStorage`.
+
+    .. warning::
+
+        :class:`TypedStorage` is deprecated. It will be removed in the future, and
+        :class:`UntypedStorage` will be the only storage class. To access the
+        :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.storage, (self,), self)
+
+    torch.storage._warn_typed_storage_removal(stacklevel=2)
+    return self._typed_storage()
+
+

storage() -> torch.TypedStorage

+

Returns the underlying :class:TypedStorage.

+
+

Warning

+

:class:TypedStorage is deprecated. It will be removed in the future, and +:class:UntypedStorage will be the only storage class. To access the +:class:UntypedStorage directly, use :attr:Tensor.untyped_storage().

+
+
+
+def storage_type(self) +
+
+
+ +Expand source code + +
def storage_type(self):
+    r"""storage_type() -> type
+
+    Returns the type of the underlying storage.
+
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.storage_type, (self,), self)
+
+    torch.storage._warn_typed_storage_removal()
+
+    return self._typed_storage()._get_legacy_storage_class()
+
+

storage_type() -> type

+

Returns the type of the underlying storage.

+
+
+def symeig(self, eigenvectors=False) +
+
+
+ +Expand source code + +
def symeig(self, eigenvectors=False):
+    from torch._linalg_utils import _symeig
+
+    return _symeig(self, eigenvectors=eigenvectors)
+
+
+
+
+def to_sparse_coo(self) +
+
+
+ +Expand source code + +
def to_sparse_coo(self):
+    """Convert a tensor to :ref:`coordinate format <sparse-coo-docs>`.
+
+    Examples::
+
+         >>> dense = torch.randn(5, 5)
+         >>> sparse = dense.to_sparse_coo()
+         >>> sparse._nnz()
+         25
+
+    """
+    return self.to_sparse()
+
+

Convert a tensor to :ref:coordinate format <sparse-coo-docs>.

+

Examples::

+
 >>> dense = torch.randn(5, 5)
+ >>> sparse = dense.to_sparse_coo()
+ >>> sparse._nnz()
+ 25
+
+
+
+def unflatten(self, dim, sizes) +
+
+
+ +Expand source code + +
def unflatten(self, dim, sizes):
+    r"""
+    unflatten(dim, sizes) -> Tensor
+
+    See :func:`torch.unflatten`.
+
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes)
+
+    if not sizes:
+        raise RuntimeError("unflatten: sizes must be non-empty")
+
+    names = None
+    if isinstance(sizes, OrderedDict) or (
+        isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))
+    ):
+        names, sizes = unzip_namedshape(sizes)
+        return super().unflatten(dim, sizes, names)
+    else:
+        return super().unflatten(dim, sizes)
+
+

unflatten(dim, sizes) -> Tensor

+

See :func:torch.unflatten.

+
+
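A one-line sketch of unflatten; the shape values are arbitrary:

>>> t = torch.randn(3, 4, 1)
>>> t.unflatten(1, (2, 2)).shape    # split dim 1 of size 4 into (2, 2)
torch.Size([3, 2, 2, 1])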
+def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None) +
+
+
+ +Expand source code + +
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
+    r"""Returns the unique elements of the input tensor.
+
+    See :func:`torch.unique`
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.unique,
+            (self,),
+            self,
+            sorted=sorted,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
+    return torch.unique(
+        self,
+        sorted=sorted,
+        return_inverse=return_inverse,
+        return_counts=return_counts,
+        dim=dim,
+    )
+
+

Returns the unique elements of the input tensor.

+

See :func:torch.unique

+
+
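A brief sketch of unique with counts; the input values are arbitrary:

>>> t = torch.tensor([1, 3, 2, 3, 1])
>>> t.unique()                                     # sorted unique values
tensor([1, 2, 3])
>>> values, counts = t.unique(return_counts=True)
>>> counts
tensor([2, 1, 2])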
+def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None) +
+
+
+ +Expand source code + +
def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
+    r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+    See :func:`torch.unique_consecutive`
+    """
+    if has_torch_function_unary(self):
+        return handle_torch_function(
+            Tensor.unique_consecutive,
+            (self,),
+            self,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
+    return torch.unique_consecutive(
+        self, return_inverse=return_inverse, return_counts=return_counts, dim=dim
+    )
+
+

Eliminates all but the first element from every consecutive group of equivalent elements.

+

See :func:torch.unique_consecutive

+
+
+
+
+class SpacyToken +(...) +
+
+

An individual token – i.e. a word, punctuation symbol, whitespace, +etc.

+

DOCS: https://spacy.io/api/token

+
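A hedged sketch of reading a few of the attributes documented below; the pipeline name en_core_web_sm is an assumption and the parse output depends on the loaded model:

>>> import spacy
>>> nlp = spacy.load("en_core_web_sm")             # placeholder pipeline
>>> doc = nlp("The quick fox jumps")
>>> token = doc[3]                                 # "jumps"
>>> [t.text for t in token.children]               # immediate syntactic children
>>> [t.text for t in token.ancestors]              # syntactic ancestors
>>> token.is_alpha, token.lemma_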

Instance variables

+
+
var ancestors
+
+

A sequence of this token's syntactic ancestors.

+

YIELDS (Token): A sequence of ancestor tokens such that +ancestor.is_ancestor(self).

+

DOCS: https://spacy.io/api/token#ancestors

+
+
var children
+
+

A sequence of the token's immediate syntactic children.

+

YIELDS (Token): A child token such that child.head==self.

+

DOCS: https://spacy.io/api/token#children

+
+
var cluster
+
+

RETURNS (int): Brown cluster ID.

+
+
var conjuncts
+
+

A sequence of coordinated tokens, including the token itself.

+

RETURNS (tuple): The coordinated tokens.

+

DOCS: https://spacy.io/api/token#conjuncts

+
+
var dep
+
+

RETURNS (uint64): ID of syntactic dependency label.

+
+
var dep_
+
+

RETURNS (str): The syntactic dependency label.

+
+
var doc
+
+
+
+
var ent_id
+
+

RETURNS (uint64): ID of the entity the token is an instance of, +if any.

+
+
var ent_id_
+
+

RETURNS (str): ID of the entity the token is an instance of, +if any.

+
+
var ent_iob
+
+

IOB code of named entity tag. 1="I", 2="O", 3="B". 0 means no tag +is assigned.

+

RETURNS (uint64): IOB code of named entity tag.

+
+
var ent_iob_
+
+

IOB code of named entity tag. "B" means the token begins an entity, +"I" means it is inside an entity, "O" means it is outside an entity, +and "" means no entity tag is set. "B" with an empty ent_type +means that the token is blocked from further processing by NER.

+

RETURNS (str): IOB code of named entity tag.

+
+
var ent_kb_id
+
+

RETURNS (uint64): Named entity KB ID.

+
+
var ent_kb_id_
+
+

RETURNS (str): Named entity KB ID.

+
+
var ent_type
+
+

RETURNS (uint64): Named entity type.

+
+
var ent_type_
+
+

RETURNS (str): Named entity type.

+
+
var has_vector
+
+

A boolean value indicating whether a word vector is associated with +the object.

+

RETURNS (bool): Whether a word vector is associated with the object.

+

DOCS: https://spacy.io/api/token#has_vector

+
+
var head
+
+

The syntactic parent, or "governor", of this token. If token.has_head() is False, this method will return itself.

+

RETURNS (Token): The token predicted by the parser to be the head of the current token.

+
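A minimal sketch of head, assuming en_core_web_sm; parses vary by model. Every token points at its syntactic governor, and the sentence root is its own head.

import spacy

nlp = spacy.load('en_core_web_sm')
doc = nlp('She ate the green apple')
for token in doc:
    print(token.text, '<--', token.dep_, '--', token.head.text)
# the sentence root typically satisfies token.head == token
root = [t for t in doc if t.head == t][0]
print(root.text)  # e.g. 'ate'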
+
var i
+
+
+
+
var idx
+
+

RETURNS (int): The character offset of the token within the parent +document.

+
+
var is_alpha
+
+

RETURNS (bool): Whether the token consists of alpha characters. +Equivalent to token.text.isalpha().

+
+
var is_ascii
+
+

RETURNS (bool): Whether the token consists of ASCII characters. Equivalent to all(ord(c) < 128 for c in token.text).

+
+
var is_bracket
+
+

RETURNS (bool): Whether the token is a bracket.

+
+
var is_currency
+
+

RETURNS (bool): Whether the token is a currency symbol.

+
+
var is_digit
+
+

RETURNS (bool): Whether the token consists of digits. Equivalent to +token.text.isdigit().

+
+
var is_left_punct
+
+

RETURNS (bool): Whether the token is a left punctuation mark.

+
+
var is_lower
+
+

RETURNS (bool): Whether the token is in lowercase. Equivalent to +token.text.islower().

+
+
var is_oov
+
+

RETURNS (bool): Whether the token is out-of-vocabulary.

+
+
var is_punct
+
+

RETURNS (bool): Whether the token is punctuation.

+
+
var is_quote
+
+

RETURNS (bool): Whether the token is a quotation mark.

+
+
var is_right_punct
+
+

RETURNS (bool): Whether the token is a right punctuation mark.

+
+
var is_sent_end
+
+

A boolean value indicating whether the token ends a sentence. +None if unknown. Defaults to True for the last token in the Doc.

+

RETURNS (bool / None): Whether the token ends a sentence. +None if unknown.

+

DOCS: https://spacy.io/api/token#is_sent_end

+
+
var is_sent_start
+
+

A boolean value indicating whether the token starts a sentence. None if unknown. Defaults to True for the first token in the Doc.

+

RETURNS (bool / None): Whether the token starts a sentence. None if unknown.

+
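A minimal sketch, assuming en_core_web_sm (which sets sentence boundaries via its parser): is_sent_start marks the first token of each sentence.

import spacy

nlp = spacy.load('en_core_web_sm')
doc = nlp('This is the first sentence. Here is the second one.')
starts = [token.text for token in doc if token.is_sent_start]
print(starts)  # ['This', 'Here']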
+
var is_space
+
+

RETURNS (bool): Whether the token consists of whitespace characters. +Equivalent to token.text.isspace().

+
+
var is_stop
+
+

RETURNS (bool): Whether the token is a stop word, i.e. part of a +"stop list" defined by the language data.

+
+
var is_title
+
+

RETURNS (bool): Whether the token is in titlecase. Equivalent to +token.text.istitle().

+
+
var is_upper
+
+

RETURNS (bool): Whether the token is in uppercase. Equivalent to +token.text.isupper()

+
+
var lang
+
+

RETURNS (uint64): ID of the language of the parent document's +vocabulary.

+
+
var lang_
+
+

RETURNS (str): Language of the parent document's vocabulary, +e.g. 'en'.

+
+
var left_edge
+
+

The leftmost token of this token's syntactic descendents.

+

RETURNS (Token): The first token such that self.is_ancestor(token).

+
+
var lefts
+
+

The leftward immediate children of the word, in the syntactic +dependency parse.

+

YIELDS (Token): A left-child of the token.

+

DOCS: https://spacy.io/api/token#lefts

+
+
var lemma
+
+

RETURNS (uint64): ID of the base form of the word, with no +inflectional suffixes.

+
+
var lemma_
+
+

RETURNS (str): The token lemma, i.e. the base form of the word, +with no inflectional suffixes.

+
+
var lex
+
+

RETURNS (Lexeme): The underlying lexeme.

+
+
var lex_id
+
+

RETURNS (int): Sequential ID of the token's lexical type.

+
+
var like_email
+
+

RETURNS (bool): Whether the token resembles an email address.

+
+
var like_num
+
+

RETURNS (bool): Whether the token resembles a number, e.g. "10.9", +"10", "ten", etc.

+
+
var like_url
+
+

RETURNS (bool): Whether the token resembles a URL.

+
+
var lower
+
+

RETURNS (uint64): ID of the lowercase token text.

+
+
var lower_
+
+

RETURNS (str): The lowercase token text. Equivalent to +Token.text.lower().

+
+
var morph
+
+
+
+
var n_lefts
+
+

The number of leftward immediate children of the word, in the +syntactic dependency parse.

+

RETURNS (int): The number of leftward immediate children of the +word, in the syntactic dependency parse.

+

DOCS: https://spacy.io/api/token#n_lefts

+
+
var n_rights
+
+

The number of rightward immediate children of the word, in the syntactic dependency parse.

+

RETURNS (int): The number of rightward immediate children of the word, in the syntactic dependency parse.

+

DOCS: https://spacy.io/api/token#n_rights

+
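A minimal sketch relating lefts, rights, n_lefts and n_rights, assuming en_core_web_sm; exact attachments depend on the parser.

import spacy

nlp = spacy.load('en_core_web_sm')
doc = nlp('I like New York in Autumn.')
token = doc[3]  # 'York'
print([t.text for t in token.lefts])   # e.g. ['New']
print([t.text for t in token.rights])  # e.g. ['in']
print(token.n_lefts, token.n_rights)   # e.g. 1 1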
+
var norm
+
+

RETURNS (uint64): ID of the token's norm, i.e. a normalised form of +the token text. Usually set in the language's tokenizer exceptions +or norm exceptions.

+
+
var norm_
+
+

RETURNS (str): The token's norm, i.e. a normalised form of the +token text. Usually set in the language's tokenizer exceptions or +norm exceptions.

+
+
var orth
+
+

RETURNS (uint64): ID of the verbatim text content.

+
+
var orth_
+
+

RETURNS (str): Verbatim text content (identical to +Token.text). Exists mostly for consistency with the other +attributes.

+
+
var pos
+
+

RETURNS (uint64): ID of coarse-grained part-of-speech tag.

+
+
var pos_
+
+

RETURNS (str): Coarse-grained part-of-speech tag.

+
+
var prefix
+
+

RETURNS (uint64): ID of a length-N substring from the start of the +token. Defaults to N=1.

+
+
var prefix_
+
+

RETURNS (str): A length-N substring from the start of the token. +Defaults to N=1.

+
+
var prob
+
+

RETURNS (float): Smoothed log probability estimate of token type.

+
+
var rank
+
+

RETURNS (int): Sequential ID of the token's lexical type, used to +index into tables, e.g. for word vectors.

+
+
var right_edge
+
+

The rightmost token of this token's syntactic descendents.

+

RETURNS (Token): The last token such that self.is_ancestor(token).

+
+
var rights
+
+

The rightward immediate children of the word, in the syntactic +dependency parse.

+

YIELDS (Token): A right-child of the token.

+

DOCS: https://spacy.io/api/token#rights

+
+
var sent
+
+

RETURNS (Span): The sentence span that the token is a part of.

+
+
var sent_start
+
+

Deprecated: use Token.is_sent_start instead.

+
+
var sentiment
+
+

RETURNS (float): A scalar value indicating the positivity or +negativity of the token.

+
+
var shape
+
+

RETURNS (uint64): ID of the token's shape, a transform of the +token's string, to show orthographic features (e.g. "Xxxx", "dd").

+
+
var shape_
+
+

RETURNS (str): Transform of the token's string, to show +orthographic features. For example, "Xxxx" or "dd".

+
+
var subtree
+
+

A sequence containing the token and all the token's syntactic descendants.

+

YIELDS (Token): A descendent token such that self.is_ancestor(descendent) or token == self.

+

DOCS: https://spacy.io/api/token#subtree

+
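A minimal sketch of subtree, assuming en_core_web_sm: subtree yields the token plus all of its descendants in document order (left_edge and right_edge give its boundaries).

import spacy

nlp = spacy.load('en_core_web_sm')
doc = nlp('Credit and mortgage account holders must submit their requests')
holders = doc[4]
# e.g. ['Credit', 'and', 'mortgage', 'account', 'holders']
print([t.text for t in holders.subtree])
print(holders.left_edge.text, holders.right_edge.text)  # e.g. 'Credit' 'holders'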
+
var suffix
+
+

RETURNS (uint64): ID of a length-N substring from the end of the +token. Defaults to N=3.

+
+
var suffix_
+
+

RETURNS (str): A length-N substring from the end of the token. +Defaults to N=3.

+
+
var tag
+
+

RETURNS (uint64): ID of fine-grained part-of-speech tag.

+
+
var tag_
+
+

RETURNS (str): Fine-grained part-of-speech tag.

+
+
var tensor
+
+
+
+
var text
+
+

RETURNS (str): The original verbatim text of the token.

+
+
var text_with_ws
+
+

RETURNS (str): The text content of the span (with trailing +whitespace).

+
+
var vector
+
+

A real-valued meaning representation.

+

RETURNS (numpy.ndarray[ndim=1, dtype='float32']): A 1D numpy array representing the token's semantics.

+

DOCS: https://spacy.io/api/token#vector

+
+
var vector_norm
+
+

The L2 norm of the token's vector representation.

+

RETURNS (float): The L2 norm of the vector representation.

+

DOCS: https://spacy.io/api/token#vector_norm

+
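A minimal sketch of has_vector, vector and vector_norm, assuming a pipeline that ships static word vectors, e.g. en_core_web_md (the small models do not include them).

import spacy

nlp = spacy.load('en_core_web_md')
doc = nlp('apples and oranges')
token = doc[0]
print(token.has_vector)    # True
print(token.vector.shape)  # e.g. (300,)
print(float(token.vector_norm))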
+
var vocab
+
+
+
+
var whitespace_
+
+

RETURNS (str): The trailing whitespace character, if present.

+
+
+

Methods

+
+
+def check_flag(...) +
+
+

Token.check_flag(self, attr_id_t flag_id) -> bool +Check the value of a boolean flag.

+
    flag_id (int): The ID of the flag attribute.
+    RETURNS (bool): Whether the flag is set.
+
+    DOCS: <https://spacy.io/api/token#check_flag>
+
+
+
+def get_extension(...) +
+
+

Token.get_extension(type cls, name) +Look up a previously registered extension by name.

+
    name (str): Name of the extension.
+    RETURNS (tuple): A <code>(default, method, getter, setter)</code> tuple.
+
+    DOCS: <https://spacy.io/api/token#get_extension>
+
+
+
+def has_dep(...) +
+
+

Token.has_dep(self) +Check whether the token has annotated dep information. +Returns False when the dep label is unset/missing.

+
    RETURNS (bool): Whether the dep label is valid or not.
+
+
+
+def has_extension(...) +
+
+

Token.has_extension(type cls, name) +Check whether an extension has been registered.

+
    name (str): Name of the extension.
+    RETURNS (bool): Whether the extension has been registered.
+
+    DOCS: <https://spacy.io/api/token#has_extension>
+
+
+
+def has_head(...) +
+
+

Token.has_head(self) +Check whether the token has annotated head information. +Return False when the head annotation is unset/missing.

+
    RETURNS (bool): Whether the head annotation is valid or not.
+
+
+
+def has_morph(...) +
+
+

Token.has_morph(self) +Check whether the token has annotated morph information. +Return False when the morph annotation is unset/missing.

+
    RETURNS (bool): Whether the morph annotation is set.
+
+
+
+def iob_strings(...) +
+
+

Token.iob_strings(type cls)

+
+
+def is_ancestor(...) +
+
+

Token.is_ancestor(self, descendant) +Check whether this token is a parent, grandparent, etc. of another +in the dependency tree.

+
    descendant (Token): Another token.
+    RETURNS (bool): Whether this token is the ancestor of the descendant.
+
+    DOCS: <https://spacy.io/api/token#is_ancestor>
+
+
+
+def nbor(...) +
+
+

Token.nbor(self, int i=1) +Get a neighboring token.

+
    i (int): The relative position of the token to get. Defaults to 1.
+    RETURNS (Token): The token at position `self.doc[self.i+i]`.
+
+    DOCS: <https://spacy.io/api/token#nbor>
+
+
+
+def remove_extension(...) +
+
+

Token.remove_extension(type cls, name) +Remove a previously registered extension.

+
    name (str): Name of the extension.
+    RETURNS (tuple): A <code>(default, method, getter, setter)</code> tuple of the
+        removed extension.
+
+    DOCS: <https://spacy.io/api/token#remove_extension>
+
+
+
+def set_extension(...) +
+
+

Token.set_extension(type cls, name, **kwargs)
Define a custom attribute which becomes available as Token._.

+
    name (str): Name of the attribute to set.
+    default: Optional default value of the attribute.
+    getter (callable): Optional getter function.
+    setter (callable): Optional setter function.
+    method (callable): Optional method for method extension.
+    force (bool): Force overwriting existing attribute.
+
+    DOCS: <https://spacy.io/api/token#set_extension>
+    USAGE: <https://spacy.io/usage/processing-pipelines#custom-components-attributes>
+
+
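A minimal sketch of registering and using a custom extension attribute; the attribute name is only an illustration.

import spacy
from spacy.tokens import Token

# register once, then read/write the value via Token._
Token.set_extension('is_fruit', default=False, force=True)

nlp = spacy.load('en_core_web_sm')
doc = nlp('I like apples')
doc[2]._.is_fruit = True
print(doc[2]._.is_fruit, Token.has_extension('is_fruit'))  # True True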
+
+def set_morph(...) +
+
+

Token.set_morph(self, features)

+
+
+def similarity(...) +
+
+

Token.similarity(self, other)
Make a semantic similarity estimate. The default estimate is cosine similarity using an average of word vectors.

+
    other (object): The object to compare with. By default, accepts <code><a title="lang_main.types.Doc" href="#lang_main.types.Doc">Doc</a></code>,
+        <code>Span</code>, <code><a title="lang_main.types.Token" href="#lang_main.types.Token">Token</a></code> and <code>Lexeme</code> objects.
+    RETURNS (float): A scalar similarity score. Higher is more similar.
+
+    DOCS: <https://spacy.io/api/token#similarity>
+
+
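A minimal sketch of token-level similarity, assuming a pipeline with word vectors such as en_core_web_md; without vectors the scores are not meaningful.

import spacy

nlp = spacy.load('en_core_web_md')
doc = nlp('apples oranges cars')
apples, oranges, cars = doc
print(apples.similarity(oranges))  # relatively high
print(apples.similarity(cars))     # lower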
+
+
+
+
+
+ +
diff --git a/src/lang_main/analysis/shared.py b/src/lang_main/analysis/shared.py
index a90df48..04cb674 100644
--- a/src/lang_main/analysis/shared.py
+++ b/src/lang_main/analysis/shared.py
@@ -5,9 +5,6 @@ from typing import cast
 import networkx as nx
 import numpy as np
 import numpy.typing as npt
-
-# import sentence_transformers  # TODO check removal
-# import sentence_transformers.util  # TODO check removal
 from networkx import Graph
 from pandas import DataFrame, Series
 from sentence_transformers import SentenceTransformer
diff --git a/src/lang_main/analysis/timeline.py b/src/lang_main/analysis/timeline.py
index 0c8c0cd..5819678 100644
--- a/src/lang_main/analysis/timeline.py
+++ b/src/lang_main/analysis/timeline.py
@@ -47,7 +47,7 @@ def _non_relevant_obj_ids(
         feats_per_obj_id = feats_per_obj_id.dropna()
         unique_feats_per_obj_id = len(feats_per_obj_id.unique())
 
-        if unique_feats_per_obj_id > thresh_unique_feat_per_id:
+        if unique_feats_per_obj_id >= thresh_unique_feat_per_id:
             ids_to_ignore.add(obj_id)
 
     return tuple(ids_to_ignore)
diff --git a/src/lang_main/model_loader.py b/src/lang_main/model_loader.py
index a00c3d2..0ac98c0 100644
--- a/src/lang_main/model_loader.py
+++ b/src/lang_main/model_loader.py
@@ -119,7 +119,7 @@ def _preprocess_STFR_model_name(
             raise FileNotFoundError(
                 f'Target model >{model_name}< not found under {model_path}'
             )
-        model_name_or_path = str(model_path)
+        model_name_or_path = str(model_path)  # pragma: no cover
     else:
         model_name_or_path = model_name
diff --git a/src/lang_main/pipelines/predefined.py b/src/lang_main/pipelines/predefined.py
index 8a5e6d0..4a399f3 100644
--- a/src/lang_main/pipelines/predefined.py
+++ b/src/lang_main/pipelines/predefined.py
@@ -30,11 +30,12 @@ from lang_main.constants import (
     DATE_COLS,
     FEATURE_NAME_OBJ_ID,
     FEATURE_NAME_OBJ_TEXT,
+    MAX_EDGE_NUMBER,
     MODEL_INPUT_FEATURES,
     NAME_DELTA_FEAT_TO_REPAIR,
     SAVE_PATH_FOLDER,
+    TARGET_FEATURE,
     THRESHOLD_AMOUNT_CHARACTERS,
-    THRESHOLD_EDGE_NUMBER,
     THRESHOLD_NUM_ACTIVITIES,
     THRESHOLD_SIMILARITY,
     THRESHOLD_TIMELINE_SIMILARITY,
@@ -72,7 +73,7 @@ def build_base_target_feature_pipe() -> Pipeline:
     pipe_target_feat.add(
         entry_wise_cleansing,
         {
-            'target_features': ('VorgangsBeschreibung',),
+            'target_features': (TARGET_FEATURE,),
             'cleansing_func': clean_string_slim,
         },
         save_result=True,
     )
     pipe_target_feat.add(
         analyse_feature,
         {
-            'target_feature': 'VorgangsBeschreibung',
+            'target_feature': TARGET_FEATURE,
         },
         save_result=True,
     )
@@ -140,7 +141,7 @@ def build_tk_graph_post_pipe() -> Pipeline:
     pipe_graph_postprocessing.add(
         graphs.filter_graph_by_number_edges,
         {
-            'limit': THRESHOLD_EDGE_NUMBER,
+            'limit': MAX_EDGE_NUMBER,
             'property': 'weight',
         },
     )
diff --git a/tests/analysis/test_graphs.py b/tests/analysis/test_graphs.py
index 1145b46..929c679 100644
--- a/tests/analysis/test_graphs.py
+++ b/tests/analysis/test_graphs.py
@@ -321,7 +321,7 @@ def test_pipe_add_graph_metrics():
 def test_pipe_rescale_graph_edge_weights(tk_graph):
     rescaled_tkg, rescaled_undir = graphs.pipe_rescale_graph_edge_weights(tk_graph)
     assert rescaled_tkg[2][1]['weight'] == pytest.approx(1.0)
-    assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.0952)
+    assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.095238)
     assert rescaled_undir[2][1]['weight'] == pytest.approx(1.0)
     assert rescaled_undir[1][2]['weight'] == pytest.approx(1.0)
@@ -331,7 +331,7 @@ def test_rescale_edge_weights(import_graph, request):
     test_graph = request.getfixturevalue(import_graph)
     rescaled_graph = graphs.rescale_edge_weights(test_graph)
     assert rescaled_graph[2][1]['weight'] == pytest.approx(1.0)
-    assert rescaled_graph[1][2]['weight'] == pytest.approx(0.0952)
+    assert rescaled_graph[1][2]['weight'] == pytest.approx(0.095238)
 
 
 @pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
diff --git a/tests/analysis/test_timeline.py b/tests/analysis/test_timeline.py
index 374b882..1ab6bbc 100644
--- a/tests/analysis/test_timeline.py
+++ b/tests/analysis/test_timeline.py
@@ -72,7 +72,7 @@ def test_calc_delta_to_repair(data_pre_cleaned, convert_to_days):
 def test_non_relevant_obj_ids(data_pre_cleaned):
     feature_uniqueness = 'HObjektText'
     feature_obj_id = 'ObjektID'
-    threshold = 1
+    threshold = 2
     data = data_pre_cleaned.copy()
     data.at[0, feature_obj_id] = 1
     ids_to_ignore = tl._non_relevant_obj_ids(
@@ -88,7 +88,7 @@ def test_remove_non_relevant_obj_ids(data_pre_cleaned):
     feature_uniqueness = 'HObjektText'
     feature_obj_id = 'ObjektID'
-    threshold = 1
+    threshold = 2
     data = data_pre_cleaned.copy()
     data.at[0, feature_obj_id] = 1
diff --git a/tests/test_model_loader.py b/tests/test_model_loader.py
index 1127383..09179bc 100644
--- a/tests/test_model_loader.py
+++ b/tests/test_model_loader.py
@@ -25,8 +25,6 @@ from lang_main.types import LanguageModels
 @pytest.mark.parametrize(
     'model_name',
     [
-        STFRModelTypes.ALL_DISTILROBERTA_V1,
-        STFRModelTypes.ALL_MINI_LM_L12_V2,
         STFRModelTypes.ALL_MINI_LM_L6_V2,
         STFRModelTypes.ALL_MPNET_BASE_V2,
     ],
@@ -47,6 +45,25 @@ def test_load_sentence_transformer(
     assert isinstance(model, SentenceTransformer)
 
 
+def test_preprocess_STFR_model_name() -> None:
+    model_name_not_exist = 'TestModel'
+    ret_model_name = model_loader._preprocess_STFR_model_name(
+        model_name=model_name_not_exist, backend=STFRBackends.TORCH, force_download=True
+    )
+    assert ret_model_name == model_name_not_exist
+    ret_model_name = model_loader._preprocess_STFR_model_name(
+        model_name=model_name_not_exist, backend=STFRBackends.TORCH, force_download=False
+    )
+    assert ret_model_name == model_name_not_exist
+
+    model_name_exist = STFRModelTypes.E5_BASE_STS_EN_DE
+    backend_exist = STFRBackends.ONNX
+    with pytest.raises(FileNotFoundError):
+        _ = model_loader._preprocess_STFR_model_name(
+            model_name=model_name_exist, backend=backend_exist, force_download=False
+        )
+
+
 @pytest.mark.parametrize(
     'similarity_func',
     [
@@ -57,8 +74,6 @@ def test_load_sentence_transformer(
 @pytest.mark.parametrize(
     'model_name',
     [
-        STFRModelTypes.ALL_DISTILROBERTA_V1,
-        STFRModelTypes.ALL_MINI_LM_L12_V2,
         STFRModelTypes.ALL_MINI_LM_L6_V2,
         STFRModelTypes.ALL_MPNET_BASE_V2,
     ],
@@ -108,6 +123,14 @@ def test_instantiate_spacy_model():
     assert isinstance(model, Language)
 
 
+def test_fail_instantiate_spacy_model():
+    with pytest.raises(KeyError):
+        _ = model_loader.instantiate_model(
+            model_load_map=model_loader.MODEL_LOADER_MAP,
+            model='test',  # type: ignore
+        )  # type: ignore
+
+
 @pytest.mark.mload
 def test_instantiate_stfr_model():
     model = model_loader.instantiate_model(