added new graph metrics

This commit is contained in:
Florian Förster 2024-12-19 16:26:01 +01:00
parent 123869e203
commit 80a35c4658
24 changed files with 826 additions and 97 deletions
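In short: the diff adds two node-level metrics to the token graphs, normalized betweenness centrality and a derived 'importance' score (weighted degree times betweenness centrality), wires them into the metrics pipeline, and reworks config and environment-variable handling. A minimal sketch of the importance metric in plain NetworkX; the library wraps this in add_betweenness_centrality and add_importance_metric in the graphs.py hunk below:

import networkx as nx

def importance(graph: nx.Graph) -> dict:
    # importance(v) = weighted degree(v) * normalized betweenness centrality(v)
    bc = nx.betweenness_centrality(graph, normalized=True)
    wd = dict(graph.degree(weight='weight'))
    return {node: wd[node] * bc[node] for node in graph}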

View File

@ -1,27 +1,24 @@
# lang_main: Config file # lang_main: Config file
[info]
pkg = 'lang_main'
[paths] [paths]
inputs = './inputs/' inputs = './data/'
# results = './results/dummy_N_1000/' # results = './results/dummy_N_1000/'
# dataset = '../data/Dummy_Dataset_N_1000.csv' # dataset = '../data/Dummy_Dataset_N_1000.csv'
results = './results/test_20240807/' results = './data/'
dataset = '../data/02_202307/Export4.csv' models = '../lang-models'
[logging] [logging]
enabled = true enabled = true
stderr = true stderr = true
file = true file = true
# only debugging features, production-ready pipelines should always # control which pipelines are executed
# be fully executed
[control] [control]
preprocessing_skip = true preprocessing_skip = false
token_analysis_skip = false token_analysis_skip = false
graph_postprocessing_skip = false graph_postprocessing_skip = false
graph_rescaling_skip = false graph_rescaling_skip = false
graph_static_rendering_skip = false graph_static_rendering_skip = true
time_analysis_skip = true time_analysis_skip = true
[preprocess] [preprocess]
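
For illustration, the [control] flags above can be read directly with the standard library; a minimal sketch assuming Python 3.11+ (tomllib) and the config file name used throughout the repo:

import tomllib
from pathlib import Path

with (Path.cwd() / 'lang_main_config.toml').open('rb') as f:
    cfg = tomllib.load(f)

# each *_skip flag toggles one pipeline stage
skip_preprocessing = cfg['control']['preprocessing_skip']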

View File

@ -155,7 +155,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 72, "execution_count": 1,
"id": "0a48d11d-1f2b-475e-9ddf-bb9a3f67accb", "id": "0a48d11d-1f2b-475e-9ddf-bb9a3f67accb",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -165,53 +165,471 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 73, "execution_count": 2,
"id": "e340377a-0df4-44ca-b18e-8b354e273eb9", "id": "e340377a-0df4-44ca-b18e-8b354e273eb9",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"save_pth = Path.cwd() / 'test.graphml'" "save_pth = Path.cwd() / 'tk_graph_built.pkl'\n",
"pth_export = Path.cwd() / 'tk_graph_built'\n",
"assert save_pth.exists()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 82, "execution_count": 3,
"id": "66677ad0-a1e5-4772-a0ba-7fbeeda55297", "id": "8aba87b2-e924-4748-98c6-05c4676f3c08",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking iteratively for config file. Start: A:\\Arbeitsaufgaben\\lang-main\\src\\lang_main, stop folder: src\n",
"Loaded TOML config file successfully.\n",
"Loaded config from: >>A:\\Arbeitsaufgaben\\lang-main\\lang_main_config.toml<<\n",
"Library path is: A:\\Arbeitsaufgaben\\lang-main\n",
"Root path is: A:\\Arbeitsaufgaben\n"
]
}
],
"source": [ "source": [
"nx.write_graphml(G, save_pth)" "from lang_main import io"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 84, "execution_count": 5,
"id": "f01ebe25-56b9-410a-a2bf-d5a6e211de7a", "id": "728fdce3-cfe0-4c4b-bcbf-c3d61547e94f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-12-19 12:24:12 +0000 | lang_main:io:INFO | Loaded file successfully.\n"
]
}
],
"source": [ "source": [
"G_load = nx.read_graphml(save_pth, node_type=int)" "G = io.load_pickle(save_pth)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 85, "execution_count": 6,
"id": "10bfad35-1f96-41a1-9014-578313502e6c", "id": "fcb74134-4192-4d68-a535-f5a502d02b67",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"OutEdgeView([(1, 2), (1, 3), (1, 4), (2, 4), (2, 1), (3, 4)])" "<networkx.classes.graph.Graph at 0x222d01224d0>"
] ]
}, },
"execution_count": 85, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"G_load.edges" "G.undirected"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "024560e5-373a-46f4-b0f7-7794535daa78",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"id": "65745c09-e834-47c9-b761-aabd5fa03e57",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-12-19 11:41:58 +0000 | lang_main:graphs:INFO | Successfully saved graph as GraphML file under A:\\Arbeitsaufgaben\\lang-main\\notebooks\\tk_graph_built.graphml.\n"
]
}
],
"source": [
"G.to_GraphML(Path.cwd(), 'tk_graph_built')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "8da93a34-7bb1-4b9c-851b-0793c2c483bf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TokenGraph(name: TokenGraph, number of nodes: 13, number of edges: 9)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"G"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "1cc98c6e-4fa5-49ed-bb97-c62d8e103ccc",
"metadata": {},
"outputs": [],
"source": [
"G_copy = G.copy()\n",
"G_copy = G_copy.undirected"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "a5044a6f-8903-4e33-a56e-abf647019d8a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'degree_weighted': 2}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 4}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 2}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 1}\n",
"{'degree_weighted': 1}\n"
]
}
],
"source": [
"for node in G_copy.nodes:\n",
" print(G_copy.nodes[node])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "720015a4-7338-4fce-ac6c-38e691e0efa4",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 76,
"id": "b21a34c8-5748-42b1-947e-76c3ba02a240",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 77,
"id": "1bcd4a31-2ea5-4123-8fc3-27017acd7259",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'weight': 1}\n",
"{'weight': 1}\n",
"{'weight': 20}\n",
"{'weight': 15}\n",
"{'weight': 10}\n",
"{'weight': 6}\n",
"{'weight': 1}\n",
"{'weight': 1}\n",
"{'weight': 1}\n"
]
}
],
"source": [
"for edge in G_copy.edges:\n",
" print(G_copy.edges[edge])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfe5bacb-1f51-45d8-b920-ca9cba01282c",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 32,
"id": "48e14d11-0319-47a9-81be-c696f69d55b3",
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "479deea9-a44f-48eb-95d3-7a326a83f62c",
"metadata": {},
"outputs": [],
"source": [
"mapping = nx.betweenness_centrality(G_copy, normalized=True)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "2299dc98-0a82-4009-8b38-790ffe3a3fd8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Kontrolle': 0.015151515151515152,\n",
" 'Lichtschranke': 0.0,\n",
" 'Überprüfung': 0.09090909090909091,\n",
" 'Spannrolle': 0.0,\n",
" 'Druckventil': 0.0,\n",
" 'Schmiernippel': 0.0,\n",
" 'Inspektion': 0.015151515151515152,\n",
" 'Förderbänder': 0.0,\n",
" 'Reinigung': 0.0,\n",
" 'Luftfilter': 0.0,\n",
" 'Schutzabdeckung': 0.0,\n",
" 'Ölstand': 0.0,\n",
" 'Hydraulik': 0.0}"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mapping"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "4c00dfe4-29d3-45d8-9eec-6d8e1fbda910",
"metadata": {},
"outputs": [],
"source": [
"nx.set_node_attributes(G_copy, mapping, name='BC')"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "0e1119e1-700d-4ae0-bb36-6d4a8fe57698",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'degree_weighted': 2, 'BC': 0.015151515151515152}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 4, 'BC': 0.09090909090909091}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 2, 'BC': 0.015151515151515152}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0}\n"
]
}
],
"source": [
"nodes_prop_mapping = {}\n",
"for node in G_copy.nodes:\n",
" node_data = G_copy.nodes[node]\n",
" prio = node_data['degree_weighted'] * node_data['BC']\n",
" nodes_prop_mapping[node] = prio\n",
" print(G_copy.nodes[node])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "2970267e-5c9e-472f-857b-382e2406f34d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Kontrolle': 0.030303030303030304,\n",
" 'Lichtschranke': 0.0,\n",
" 'Überprüfung': 0.36363636363636365,\n",
" 'Spannrolle': 0.0,\n",
" 'Druckventil': 0.0,\n",
" 'Schmiernippel': 0.0,\n",
" 'Inspektion': 0.030303030303030304,\n",
" 'Förderbänder': 0.0,\n",
" 'Reinigung': 0.0,\n",
" 'Luftfilter': 0.0,\n",
" 'Schutzabdeckung': 0.0,\n",
" 'Ölstand': 0.0,\n",
" 'Hydraulik': 0.0}"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nodes_prop_mapping"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "ef2a1065-9a47-4df4-a7d9-ccfe30dbcdc7",
"metadata": {},
"outputs": [],
"source": [
"nx.set_node_attributes(G_copy, nodes_prop_mapping, name='prio')"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "2b77597a-be7a-498a-8374-f98324574619",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'degree_weighted': 2, 'BC': 0.015151515151515152, 'prio': 0.030303030303030304}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 4, 'BC': 0.09090909090909091, 'prio': 0.36363636363636365}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 2, 'BC': 0.015151515151515152, 'prio': 0.030303030303030304}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n"
]
}
],
"source": [
"for node in G_copy.nodes:\n",
" print(G_copy.nodes[node])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "089ae831-f307-473d-8a70-e72bc580ca61",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "81d9e8ad-4372-4e3b-a42b-e5b343802936",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ea0a1fc3-571f-4ce2-8561-b0a8f93e5072",
"metadata": {},
"outputs": [],
"source": [
"def build_init_graph():\n",
" edge_weights = [\n",
" {'weight': 1},\n",
" {'weight': 2},\n",
" {'weight': 3},\n",
" {'weight': 4},\n",
" {'weight': 5},\n",
" {'weight': 6},\n",
" ]\n",
" edges = [\n",
" (1, 2),\n",
" (1, 3),\n",
" (2, 4),\n",
" (3, 4),\n",
" (1, 4),\n",
" (2, 1),\n",
" ]\n",
" edges_to_add = []\n",
" for i, edge in enumerate(edges):\n",
" edge = list(edge)\n",
" edge.append(edge_weights[i]) # type: ignore\n",
" edges_to_add.append(tuple(edge))\n",
"\n",
" G = nx.DiGraph()\n",
"\n",
" G.add_edges_from(edges_to_add)\n",
"\n",
" return G"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "6e97af5f-f629-4c71-a1e4-2310ad2f3caa",
"metadata": {},
"outputs": [],
"source": [
"G_init = build_init_graph()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4ff8fb6b-a10b-448f-b8a5-9989f4bee81e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{1: 0.16666666666666666, 2: 0.0, 3: 0.0, 4: 0.0}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nx.betweenness_centrality(G_init, normalized=True)"
] ]
}, },
{ {
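
Condensed, the notebook cells above prototype the new metrics interactively; the flow is roughly the following sketch (API names as they appear in the cells, file paths assumed):

from pathlib import Path
from lang_main import io

G = io.load_pickle(Path.cwd() / 'tk_graph_built.pkl')  # TokenGraph
G_undirected = G.undirected                            # undirected view used for the metrics
G.to_GraphML(Path.cwd(), 'tk_graph_built')             # writes the GraphML file shown next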

View File

@ -0,0 +1,73 @@
<?xml version='1.0' encoding='utf-8'?>
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
<key id="d1" for="edge" attr.name="weight" attr.type="long" />
<key id="d0" for="node" attr.name="degree_weighted" attr.type="long" />
<graph edgedefault="undirected">
<node id="Kontrolle">
<data key="d0">2</data>
</node>
<node id="Lichtschranke">
<data key="d0">1</data>
</node>
<node id="Überprüfung">
<data key="d0">4</data>
</node>
<node id="Spannrolle">
<data key="d0">1</data>
</node>
<node id="Druckventil">
<data key="d0">1</data>
</node>
<node id="Schmiernippel">
<data key="d0">1</data>
</node>
<node id="Inspektion">
<data key="d0">2</data>
</node>
<node id="Förderbänder">
<data key="d0">1</data>
</node>
<node id="Reinigung">
<data key="d0">1</data>
</node>
<node id="Luftfilter">
<data key="d0">1</data>
</node>
<node id="Schutzabdeckung">
<data key="d0">1</data>
</node>
<node id="Ölstand">
<data key="d0">1</data>
</node>
<node id="Hydraulik">
<data key="d0">1</data>
</node>
<edge source="Kontrolle" target="Lichtschranke">
<data key="d1">1</data>
</edge>
<edge source="Kontrolle" target="Schmiernippel">
<data key="d1">1</data>
</edge>
<edge source="Überprüfung" target="Spannrolle">
<data key="d1">1</data>
</edge>
<edge source="Überprüfung" target="Druckventil">
<data key="d1">1</data>
</edge>
<edge source="Überprüfung" target="Ölstand">
<data key="d1">1</data>
</edge>
<edge source="Überprüfung" target="Hydraulik">
<data key="d1">1</data>
</edge>
<edge source="Inspektion" target="Förderbänder">
<data key="d1">1</data>
</edge>
<edge source="Inspektion" target="Schutzabdeckung">
<data key="d1">1</data>
</edge>
<edge source="Reinigung" target="Luftfilter">
<data key="d1">1</data>
</edge>
</graph>
</graphml>
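
A file exported like the one above can be read back with NetworkX; a small sketch (file name taken from the notebook log above):

import networkx as nx

G = nx.read_graphml('tk_graph_built.graphml')
assert G.nodes['Kontrolle']['degree_weighted'] == 2
assert G.edges['Kontrolle', 'Lichtschranke']['weight'] == 1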

Binary file not shown.

publish.ps1 Normal file
View File

@ -0,0 +1 @@
pdm publish -r local --skip-existing

View File

@ -1,6 +1,6 @@
[project] [project]
name = "lang-main" name = "lang-main"
version = "0.1.0a1" version = "0.1.0a7"
description = "Several tools to analyse TOM's data with strong focus on language processing" description = "Several tools to analyse TOM's data with strong focus on language processing"
authors = [ authors = [
{name = "d-opt GmbH, resp. Florian Förster", email = "f.foerster@d-opt.com"}, {name = "d-opt GmbH, resp. Florian Förster", email = "f.foerster@d-opt.com"},
@ -57,6 +57,19 @@ distribution = true
[tool.pdm.build] [tool.pdm.build]
package-dir = "src" package-dir = "src"
[tool.pdm.resolution]
respect-source-order = true
[[tool.pdm.source]]
name = "private"
url = "http://localhost:8001/simple"
verify_ssl = false
[[tool.pdm.source]]
name = "pypi"
url = "https://pypi.org/simple"
exclude_packages = ["lang-main*", "tom-plugin*"]
[tool.pdm.dev-dependencies] [tool.pdm.dev-dependencies]
notebooks = [ notebooks = [
"jupyterlab>=4.2.0", "jupyterlab>=4.2.0",

View File

@ -0,0 +1,2 @@
Remove-Item "./logs" -Force -Recurse
pdm run coverage run -m pytest -m "not mload and not cyto"

View File

@ -1,3 +1,4 @@
Remove-Item "./logs" -Force -Recurse
pdm run pytest --cov -n 4 pdm run pytest --cov -n 4
# run docker desktop # run docker desktop
. "C:\Program Files\Docker\Docker\Docker Desktop.exe" . "C:\Program Files\Docker\Docker\Docker Desktop.exe"

View File

@ -16,12 +16,15 @@ from pandas import DataFrame
from lang_main.constants import ( from lang_main.constants import (
EDGE_WEIGHT_DECIMALS, EDGE_WEIGHT_DECIMALS,
LOGGING_DEFAULT_GRAPHS, LOGGING_DEFAULT_GRAPHS,
PROPERTY_NAME_BETWEENNESS_CENTRALITY,
PROPERTY_NAME_DEGREE_WEIGHTED, PROPERTY_NAME_DEGREE_WEIGHTED,
PROPERTY_NAME_IMPORTANCE,
) )
from lang_main.errors import ( from lang_main.errors import (
EdgePropertyNotContainedError, EdgePropertyNotContainedError,
EmptyEdgesError, EmptyEdgesError,
EmptyGraphError, EmptyGraphError,
NodePropertyNotContainedError,
) )
from lang_main.io import load_pickle, save_pickle from lang_main.io import load_pickle, save_pickle
from lang_main.loggers import logger_graphs as logger from lang_main.loggers import logger_graphs as logger
@ -310,15 +313,98 @@ def add_weighted_degree(
property of the edges which contains the weight information, by default 'weight' property of the edges which contains the weight information, by default 'weight'
property_name : str, optional property_name : str, optional
target name for the property containing the weighted degree in nodes, target name for the property containing the weighted degree in nodes,
by default 'degree_weighted' by default PROPERTY_NAME_DEGREE_WEIGHTED
""" """
node_degree_mapping = cast( node_property_mapping = cast(
dict[str, float], dict[str, float],
dict(graph.degree(weight=edge_weight_property)), # type: ignore dict(graph.degree(weight=edge_weight_property)), # type: ignore
) )
nx.set_node_attributes( nx.set_node_attributes(
graph, graph,
node_degree_mapping, node_property_mapping,
name=property_name,
)
def add_betweenness_centrality(
graph: DiGraph | Graph,
edge_weight_property: str | None = None,
property_name: str = PROPERTY_NAME_BETWEENNESS_CENTRALITY,
) -> None:
"""adds the betweenness centrality as property to each node of the given graph
Operation is performed inplace.
Parameters
----------
graph : DiGraph | Graph
Graph with betweenness centrality as node property added inplace
edge_weight_property : str | None, optional
property of the edges which contains the weight information;
optional, by default None
property_name : str, optional
target name for the property containing the betweenness centrality in nodes,
by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
"""
node_property_mapping = cast(
dict[str, float],
nx.betweenness_centrality(graph, normalized=True, weight=edge_weight_property), # type: ignore
)
nx.set_node_attributes(
graph,
node_property_mapping,
name=property_name,
)
def add_importance_metric(
graph: DiGraph | Graph,
property_name: str = PROPERTY_NAME_IMPORTANCE,
property_name_weighted_degree: str = PROPERTY_NAME_DEGREE_WEIGHTED,
property_name_betweenness: str = PROPERTY_NAME_BETWEENNESS_CENTRALITY,
) -> None:
"""Adds a custom importance metric as property to each node of the given graph.
Can be used to decide which nodes are of high importance and also to build node size
mappings.
Operation is performed inplace.
Parameters
----------
graph : DiGraph | Graph
Graph with importance metric as node property added inplace
property_name : str, optional
target name for the property containing the importance metric in nodes,
by default PROPERTY_NAME_IMPORTANCE
property_name_weighted_degree : str, optional
name of the property containing the weighted degree in nodes,
by default PROPERTY_NAME_DEGREE_WEIGHTED
property_name_betweenness : str, optional
target name for the property containing the betweenness centrality in nodes,
by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
"""
# build mapping for importance metric
node_property_mapping: dict[str, float] = {}
for node in cast(Iterable[str], graph.nodes):
node_data = cast(dict[str, float], graph.nodes[node])
if property_name_weighted_degree not in node_data:
raise NodePropertyNotContainedError(
(
f'Node data does not contain weighted degree '
f'with name {property_name_weighted_degree}.'
)
)
elif property_name_betweenness not in node_data:
raise NodePropertyNotContainedError(
(
f'Node data does not contain betweenness centrality '
f'with name {property_name_betweenness}.'
)
)
prio = node_data[property_name_weighted_degree] * node_data[property_name_betweenness]
node_property_mapping[node] = prio
nx.set_node_attributes(
graph,
node_property_mapping,
name=property_name, name=property_name,
) )
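
Taken together, the three metric functions are applied like this; a sketch mirroring the unit tests further down (import path as used in the test module):

import networkx as nx
from lang_main.analysis import graphs

G = nx.DiGraph()
G.add_weighted_edges_from([(1, 2, 3.0), (2, 3, 1.0)])
graphs.add_weighted_degree(G, 'weight')
graphs.add_betweenness_centrality(G)
graphs.add_importance_metric(G)
# every node now carries 'degree_weighted', 'betweenness_centrality' and 'importance'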
@ -351,6 +437,8 @@ def pipe_add_graph_metrics(
for graph in graphs: for graph in graphs:
graph_copy = copy.deepcopy(graph) graph_copy = copy.deepcopy(graph)
add_weighted_degree(graph_copy) add_weighted_degree(graph_copy)
add_betweenness_centrality(graph_copy)
add_importance_metric(graph_copy)
collection.append(graph_copy) collection.append(graph_copy)
return tuple(collection) return tuple(collection)
@ -762,19 +850,3 @@ class TokenGraph(DiGraph):
raise ValueError('File format not supported.') raise ValueError('File format not supported.')
return graph return graph
# TODO check removal
# @classmethod
# def from_pickle(
# cls,
# path: str | Path,
# ) -> Self:
# if isinstance(path, str):
# path = Path(path)
# if path.suffix not in ('.pkl', '.pickle'):
# raise ValueError('File format not supported.')
# graph = typing.cast(Self, load_pickle(path))
# return graph

View File

@ -19,7 +19,7 @@ except ImportError:
# ** external packages config # ** external packages config
# ** Huggingface Hub caching # ** Huggingface Hub caching
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = 'set' os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
# ** py4cytoscape config # ** py4cytoscape config
if _has_py4cyto: if _has_py4cyto:
@ -36,7 +36,7 @@ BASE_FOLDERNAME: Final[str] = os.environ.get('LANG_MAIN_BASE_FOLDERNAME', 'lang-
CONFIG_FILENAME: Final[str] = 'lang_main_config.toml' CONFIG_FILENAME: Final[str] = 'lang_main_config.toml'
CYTO_STYLESHEET_FILENAME: Final[str] = r'cytoscape_config/lang_main.xml' CYTO_STYLESHEET_FILENAME: Final[str] = r'cytoscape_config/lang_main.xml'
PKG_DIR: Final[Path] = Path(__file__).parent PKG_DIR: Final[Path] = Path(__file__).parent
STOP_FOLDER: Final[str] = 'python' STOP_FOLDER: Final[str] = os.environ.get('LANG_MAIN_STOP_SEARCH_FOLDERNAME', 'src')
def load_toml_config( def load_toml_config(
@ -65,6 +65,7 @@ def load_cfg(
starting_path: Path, starting_path: Path,
glob_pattern: str, glob_pattern: str,
stop_folder_name: str | None, stop_folder_name: str | None,
lookup_cwd: bool = False,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Look for configuration file. Internal configs are not used any more because """Look for configuration file. Internal configs are not used any more because
the library behaviour is only guaranteed by external configurations. the library behaviour is only guaranteed by external configurations.
@ -91,7 +92,8 @@ def load_cfg(
LangMainConfigNotFoundError LangMainConfigNotFoundError
if no config file was found if no config file was found
""" """
cfg_path: Path | None cfg_path: Path | None = None
if lookup_cwd:
print('Looking for cfg file in CWD.', flush=True) print('Looking for cfg file in CWD.', flush=True)
cfg_path = search_cwd(glob_pattern) cfg_path = search_cwd(glob_pattern)
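
A call sketch for the extended lookup, with argument values mirroring the updated config test below; load_cfg is the function defined above:

from pathlib import Path

cfg = load_cfg(
    starting_path=Path.cwd(),
    glob_pattern='lang_main_config.toml',
    stop_folder_name='src',
    lookup_cwd=True,  # check the current working directory before walking up
)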

View File

@ -54,17 +54,13 @@ PICKLE_PROTOCOL_VERSION: Final[int] = 5
# config placed in library path of application (usually "bin") # config placed in library path of application (usually "bin")
input_path_cfg = LIB_PATH / Path(CONFIG['paths']['inputs']) input_path_cfg = LIB_PATH / Path(CONFIG['paths']['inputs'])
INPUT_PATH_FOLDER: Final[Path] = input_path_cfg.resolve() INPUT_PATH_FOLDER: Final[Path] = input_path_cfg.resolve()
# TODO reactivate later if not INPUT_PATH_FOLDER.exists(): # pragma: no cover
if not INPUT_PATH_FOLDER.exists():
raise FileNotFoundError(f'Input path >>{INPUT_PATH_FOLDER}<< does not exist.') raise FileNotFoundError(f'Input path >>{INPUT_PATH_FOLDER}<< does not exist.')
save_path_cfg = LIB_PATH / Path(CONFIG['paths']['results']) save_path_cfg = LIB_PATH / Path(CONFIG['paths']['results'])
SAVE_PATH_FOLDER: Final[Path] = save_path_cfg.resolve() SAVE_PATH_FOLDER: Final[Path] = save_path_cfg.resolve()
if not SAVE_PATH_FOLDER.exists(): if not SAVE_PATH_FOLDER.exists(): # pragma: no cover
raise FileNotFoundError(f'Output path >>{SAVE_PATH_FOLDER}<< does not exist.') raise FileNotFoundError(f'Output path >>{SAVE_PATH_FOLDER}<< does not exist.')
path_dataset_cfg = LIB_PATH / Path(CONFIG['paths']['dataset'])
PATH_TO_DATASET: Final[Path] = path_dataset_cfg.resolve()
# if not PATH_TO_DATASET.exists():
# raise FileNotFoundError(f'Dataset path >>{PATH_TO_DATASET}<< does not exist.')
# ** control # ** control
SKIP_PREPROCESSING: Final[bool] = CONFIG['control']['preprocessing_skip'] SKIP_PREPROCESSING: Final[bool] = CONFIG['control']['preprocessing_skip']
SKIP_TOKEN_ANALYSIS: Final[bool] = CONFIG['control']['token_analysis_skip'] SKIP_TOKEN_ANALYSIS: Final[bool] = CONFIG['control']['token_analysis_skip']
@ -82,22 +78,34 @@ MODEL_BASE_FOLDER: Final[Path] = model_folder_cfg.resolve()
if not MODEL_BASE_FOLDER.exists(): if not MODEL_BASE_FOLDER.exists():
raise FileNotFoundError('Language model folder not found.') raise FileNotFoundError('Language model folder not found.')
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(MODEL_BASE_FOLDER) os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(MODEL_BASE_FOLDER)
SPACY_MODEL_NAME: Final[SpacyModelTypes] = SpacyModelTypes.DE_CORE_NEWS_SM
STFR_MODEL_NAME: Final[STFRModelTypes] = STFRModelTypes.ALL_MPNET_BASE_V2 # LANG_MAIN_BASE_FOLDERNAME : base folder of library, not root (folder in which Python installation is found)
# LANG_MAIN_SPACY_MODEL : spaCy model used; if not provided, use constant value defined in library; more internal use
# LANG_MAIN_STFR_MODEL : Sentence Transformer model used; if not provided, use constant value defined in library; more internal use
# LANG_MAIN_STFR_BACKEND : STFR backend, choice between "torch" and "onnx"
SPACY_MODEL_NAME: Final[str | SpacyModelTypes] = os.environ.get(
'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_CORE_NEWS_SM
)
STFR_MODEL_NAME: Final[str | STFRModelTypes] = os.environ.get(
'LANG_MAIN_STFR_MODEL', STFRModelTypes.ALL_MPNET_BASE_V2
)
STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU
STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE
STFR_BACKEND: Final[STFRBackends] = STFRBackends.TORCH STFR_BACKEND: Final[str | STFRBackends] = os.environ.get(
STFR_MODEL_ARGS_DEFAULT: STFRModelArgs = {} 'LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH
STFR_MODEL_ARGS_ONNX: STFRModelArgs = { )
stfr_model_args_default: STFRModelArgs = {}
stfr_model_args_onnx: STFRModelArgs = {
'file_name': STFRQuantFilenames.ONNX_Q_UINT8, 'file_name': STFRQuantFilenames.ONNX_Q_UINT8,
'provider': ONNXExecutionProvider.CPU, 'provider': ONNXExecutionProvider.CPU,
'export': False, 'export': False,
} }
stfr_model_args: STFRModelArgs stfr_model_args: STFRModelArgs
if STFR_BACKEND == STFRBackends.ONNX: if STFR_BACKEND == STFRBackends.ONNX:
stfr_model_args = STFR_MODEL_ARGS_ONNX stfr_model_args = stfr_model_args_onnx
else: else:
stfr_model_args = STFR_MODEL_ARGS_DEFAULT stfr_model_args = stfr_model_args_default
STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args
# ** language dependency analysis # ** language dependency analysis
@ -122,6 +130,8 @@ THRESHOLD_SIMILARITY: Final[float] = CONFIG['preprocess']['threshold_similarity'
EDGE_WEIGHT_DECIMALS: Final[int] = 4 EDGE_WEIGHT_DECIMALS: Final[int] = 4
THRESHOLD_EDGE_NUMBER: Final[int] = CONFIG['graph_postprocessing']['threshold_edge_number'] THRESHOLD_EDGE_NUMBER: Final[int] = CONFIG['graph_postprocessing']['threshold_edge_number']
PROPERTY_NAME_DEGREE_WEIGHTED: Final[str] = 'degree_weighted' PROPERTY_NAME_DEGREE_WEIGHTED: Final[str] = 'degree_weighted'
PROPERTY_NAME_BETWEENNESS_CENTRALITY: Final[str] = 'betweenness_centrality'
PROPERTY_NAME_IMPORTANCE: Final[str] = 'importance'
# ** graph exports (Cytoscape) # ** graph exports (Cytoscape)
CYTO_MAX_NODE_COUNT: Final[int] = 500 CYTO_MAX_NODE_COUNT: Final[int] = 500

View File

@ -0,0 +1,6 @@
# List of the library's environment variables
LANG_MAIN_STOP_SEARCH_FOLDERNAME : foldername in package directory tree at which the lookup should stop; used to find directory root
LANG_MAIN_BASE_FOLDERNAME : base folder of library, not root (folder in which Python installation is found)
LANG_MAIN_SPACY_MODEL : spaCy model used; if not provided, use constant value defined in library; more internal use
LANG_MAIN_STFR_MODEL : Sentence Transformer model used; if not provided, use constant value defined in library; more internal use
LANG_MAIN_STFR_BACKEND : STFR backend, choice between "torch" and "onnx"
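
These variables are read via os.environ.get at module level in constants.py, so they must be set before the first import of the package; an illustrative sketch with assumed but plausible values:

import os

os.environ['LANG_MAIN_STOP_SEARCH_FOLDERNAME'] = 'src'
os.environ['LANG_MAIN_STFR_BACKEND'] = 'onnx'  # or 'torch' (the default)

import lang_main  # constants pick up the overrides on first import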

View File

@ -8,6 +8,10 @@ class LanguageModelNotFoundError(Exception):
# ** token graph exceptions # ** token graph exceptions
class NodePropertyNotContainedError(Exception):
"""Error raised if a needed node property is not contained in graph edges"""
class EdgePropertyNotContainedError(Exception): class EdgePropertyNotContainedError(Exception):
"""Error raised if a needed edge property is not contained in graph edges""" """Error raised if a needed edge property is not contained in graph edges"""

View File

@ -1,22 +1,18 @@
# lang_main: Config file # lang_main: Config file
[info]
pkg = 'lang_main_internal'
[paths] [paths]
inputs = './data/in/' inputs = '../data/in/'
# results = './results/dummy_N_1000/' # results = './results/dummy_N_1000/'
# dataset = '../data/Dummy_Dataset_N_1000.csv' # dataset = '../data/Dummy_Dataset_N_1000.csv'
results = './data/out/' results = '../data/out/'
dataset = '../data/02_202307/Export4.csv' models = './lang-models'
models = '../../lang-models'
[logging] [logging]
enabled = true enabled = true
stderr = true stderr = true
file = true file = true
# only debugging features, production-ready pipelines should always # control which pipelines are executed
# be fully executed
[control] [control]
preprocessing_skip = false preprocessing_skip = false
token_analysis_skip = false token_analysis_skip = false

View File

@ -33,7 +33,7 @@ if ENABLE_LOGGING and LOGGING_TO_STDERR:
logger_all_handler_stderr = logging.StreamHandler() logger_all_handler_stderr = logging.StreamHandler()
logger_all_handler_stderr.setLevel(LOGGING_LEVEL_STDERR) logger_all_handler_stderr.setLevel(LOGGING_LEVEL_STDERR)
logger_all_handler_stderr.setFormatter(logger_all_formater) logger_all_handler_stderr.setFormatter(logger_all_formater)
else: else: # pragma: no cover
logger_all_handler_stderr = null_handler logger_all_handler_stderr = null_handler
if ENABLE_LOGGING and LOGGING_TO_FILE: if ENABLE_LOGGING and LOGGING_TO_FILE:
@ -45,7 +45,7 @@ if ENABLE_LOGGING and LOGGING_TO_FILE:
) )
logger_all_handler_file.setLevel(LOGGING_LEVEL_FILE) logger_all_handler_file.setLevel(LOGGING_LEVEL_FILE)
logger_all_handler_file.setFormatter(logger_all_formater) logger_all_handler_file.setFormatter(logger_all_formater)
else: else: # pragma: no cover
logger_all_handler_file = null_handler logger_all_handler_file = null_handler

View File

@ -33,6 +33,7 @@ class BasePipeline(ABC):
# container for actions to perform during pass # container for actions to perform during pass
self.actions: list[Callable] = [] self.actions: list[Callable] = []
self.action_names: list[str] = [] self.action_names: list[str] = []
self.action_skip: list[bool] = []
# progress tracking, start at 1 # progress tracking, start at 1
self.curr_proc_idx: int = 1 self.curr_proc_idx: int = 1
@ -104,8 +105,6 @@ class PipelineContainer(BasePipeline):
) -> None: ) -> None:
super().__init__(name=name, working_dir=working_dir) super().__init__(name=name, working_dir=working_dir)
self.action_skip: list[bool] = []
@override @override
def add( def add(
self, self,
@ -170,6 +169,7 @@ class Pipeline(BasePipeline):
self, self,
action: Callable, action: Callable,
action_kwargs: dict[str, Any] | None = None, action_kwargs: dict[str, Any] | None = None,
skip: bool = False,
save_result: bool = False, save_result: bool = False,
load_result: bool = False, load_result: bool = False,
filename: str | None = None, filename: str | None = None,
@ -183,6 +183,7 @@ class Pipeline(BasePipeline):
self.actions.append(action) self.actions.append(action)
self.action_names.append(action.__name__) self.action_names.append(action.__name__)
self.actions_kwargs.append(action_kwargs.copy()) self.actions_kwargs.append(action_kwargs.copy())
self.action_skip.append(skip)
self.save_results.append((save_result, filename)) self.save_results.append((save_result, filename))
self.load_results.append((load_result, filename)) self.load_results.append((load_result, filename))
else: else:
@ -235,7 +236,13 @@ class Pipeline(BasePipeline):
self, self,
starting_values: tuple[Any, ...] | None = None, starting_values: tuple[Any, ...] | None = None,
) -> tuple[Any, ...]: ) -> tuple[Any, ...]:
first_performed: bool = False
for idx, (action, action_kwargs) in enumerate(zip(self.actions, self.actions_kwargs)): for idx, (action, action_kwargs) in enumerate(zip(self.actions, self.actions_kwargs)):
if self.action_skip[idx]:
self.curr_proc_idx += 1
continue
# loading # loading
if self.load_results[idx][0]: if self.load_results[idx][0]:
filename = self.load_results[idx][1] filename = self.load_results[idx][1]
@ -248,8 +255,9 @@ class Pipeline(BasePipeline):
self.curr_proc_idx += 1 self.curr_proc_idx += 1
continue continue
# calculation # calculation
if idx == 0: if not first_performed:
args = starting_values args = starting_values
first_performed = True
else: else:
args = ret args = ret
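
A usage sketch of the new per-action skip flag (module path assumed; behavior mirrors the updated pipeline tests below). A skipped action is registered but never executed, and the first action that actually runs receives the starting values:

from pathlib import Path
from lang_main.pipelines import Pipeline  # assumed import path

def step_one(text):  # hypothetical actions
    return text + '_1'

def step_two(text):
    return text + '_2'

pipe = Pipeline(name='demo', working_dir=Path.cwd())
pipe.add(step_one, skip=True)  # registered, but skipped at run time
pipe.add(step_two)
ret = pipe.run(starting_values=('test',))
# step_two is the first executed action; the tests suggest single return
# values are wrapped in a tuple, so ret[0] == 'test_2'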

View File

@ -296,7 +296,7 @@ def apply_style_to_network(
style_name: str = CYTO_STYLESHEET_NAME, style_name: str = CYTO_STYLESHEET_NAME,
pth_to_stylesheet: Path = CYTO_PATH_STYLESHEET, pth_to_stylesheet: Path = CYTO_PATH_STYLESHEET,
network_name: str = CYTO_BASE_NETWORK_NAME, network_name: str = CYTO_BASE_NETWORK_NAME,
node_size_property: str = 'node_selection', node_size_property: str = CYTO_SELECTION_PROPERTY,
min_node_size: int = 15, min_node_size: int = 15,
max_node_size: int = 40, max_node_size: int = 40,
sandbox_name: str = CYTO_SANDBOX_NAME, sandbox_name: str = CYTO_SANDBOX_NAME,

View File

@ -2,7 +2,12 @@ import networkx as nx
import pytest import pytest
from lang_main.analysis import graphs from lang_main.analysis import graphs
from lang_main.errors import EmptyEdgesError, EmptyGraphError, EdgePropertyNotContainedError from lang_main.errors import (
EdgePropertyNotContainedError,
EmptyEdgesError,
EmptyGraphError,
NodePropertyNotContainedError,
)
TK_GRAPH_NAME = 'TEST_TOKEN_GRAPH' TK_GRAPH_NAME = 'TEST_TOKEN_GRAPH'
@ -231,6 +236,49 @@ def test_add_weighted_degree():
assert graph_obj.nodes[3][property_name] == 6 assert graph_obj.nodes[3][property_name] == 6
def test_add_betweenness_centrality():
graph_obj = build_init_graph(token_graph=False)
property_name = 'betweenness_centrality'
graphs.add_betweenness_centrality(graph_obj, property_name=property_name)
assert round(graph_obj.nodes[1][property_name], 4) == pytest.approx(0.1667)
assert graph_obj.nodes[2][property_name] == 0
assert graph_obj.nodes[3][property_name] == 0
def test_add_importance_metric():
graph_obj = build_init_graph(token_graph=False)
property_name_WD = 'degree_weighted'
graphs.add_weighted_degree(graph_obj, 'weight', property_name_WD)
property_name_BC = 'betweenness_centrality'
graphs.add_betweenness_centrality(graph_obj, property_name=property_name_BC)
property_name = 'importance'
graphs.add_importance_metric(
graph_obj,
property_name=property_name,
property_name_weighted_degree=property_name_WD,
property_name_betweenness=property_name_BC,
)
assert round(graph_obj.nodes[1][property_name], 4) == pytest.approx(2.3333)
assert graph_obj.nodes[2][property_name] == 0
assert graph_obj.nodes[3][property_name] == 0
with pytest.raises(NodePropertyNotContainedError):
graphs.add_importance_metric(
graph_obj,
property_name=property_name,
property_name_weighted_degree='prop_not_contained',
property_name_betweenness=property_name_BC,
)
with pytest.raises(NodePropertyNotContainedError):
graphs.add_importance_metric(
graph_obj,
property_name=property_name,
property_name_weighted_degree=property_name_WD,
property_name_betweenness='prop_not_contained',
)
def test_static_graph_analysis(): def test_static_graph_analysis():
graph_obj = build_init_graph(token_graph=True) graph_obj = build_init_graph(token_graph=True)
(graph_obj,) = graphs.static_graph_analysis(graph_obj) # type: ignore (graph_obj,) = graphs.static_graph_analysis(graph_obj) # type: ignore
@ -254,6 +302,20 @@ def test_pipe_add_graph_metrics():
assert graph_collection[1].nodes[1][property_name] == 14 assert graph_collection[1].nodes[1][property_name] == 14
assert graph_collection[1].nodes[2][property_name] == 10 assert graph_collection[1].nodes[2][property_name] == 10
assert graph_collection[1].nodes[3][property_name] == 6 assert graph_collection[1].nodes[3][property_name] == 6
property_name = 'betweenness_centrality'
assert round(graph_collection[0].nodes[1][property_name], 4) == pytest.approx(0.1667)
assert graph_collection[0].nodes[2][property_name] == 0
assert graph_collection[0].nodes[3][property_name] == 0
assert round(graph_collection[1].nodes[1][property_name], 4) == pytest.approx(0.1667)
assert graph_collection[1].nodes[2][property_name] == 0
assert graph_collection[1].nodes[3][property_name] == 0
property_name = 'importance'
assert round(graph_collection[0].nodes[1][property_name], 4) == pytest.approx(2.3333)
assert graph_collection[0].nodes[2][property_name] == 0
assert graph_collection[0].nodes[3][property_name] == 0
assert round(graph_collection[1].nodes[1][property_name], 4) == pytest.approx(2.3333)
assert graph_collection[1].nodes[2][property_name] == 0
assert graph_collection[1].nodes[3][property_name] == 0
def test_pipe_rescale_graph_edge_weights(tk_graph): def test_pipe_rescale_graph_edge_weights(tk_graph):

View File

@ -121,6 +121,7 @@ def test_pipeline_valid(pipeline, alter_content):
assert len(pipe.actions) == 1 assert len(pipe.actions) == 1
assert len(pipe.action_names) == 1 assert len(pipe.action_names) == 1
assert len(pipe.actions_kwargs) == 1 assert len(pipe.actions_kwargs) == 1
assert len(pipe.action_skip) == 1
assert len(pipe.save_results) == 1 assert len(pipe.save_results) == 1
assert len(pipe.load_results) == 1 assert len(pipe.load_results) == 1
assert pipe.save_results[0] == (False, None) assert pipe.save_results[0] == (False, None)
@ -166,7 +167,7 @@ def test_pipeline_valid_action_load(pipeline, working_dir):
test_string = 'test' test_string = 'test'
# action preparation # action preparation
def valid_action(string, add_content=False): def valid_action(string, add_content=False): # pragma: no cover
if add_content: if add_content:
string += '_2' string += '_2'
return string return string
@ -175,6 +176,7 @@ def test_pipeline_valid_action_load(pipeline, working_dir):
assert len(pipe.actions) == 1 assert len(pipe.actions) == 1
assert len(pipe.action_names) == 1 assert len(pipe.action_names) == 1
assert len(pipe.actions_kwargs) == 1 assert len(pipe.actions_kwargs) == 1
assert len(pipe.action_skip) == 1
assert len(pipe.save_results) == 1 assert len(pipe.save_results) == 1
assert len(pipe.load_results) == 1 assert len(pipe.load_results) == 1
assert pipe.save_results[0] == (False, None) assert pipe.save_results[0] == (False, None)
@ -209,19 +211,28 @@ def test_pipeline_multiple_actions(pipeline):
string += '_3' string += '_3'
return string return string
pipe.add(valid_action, {'add_content': True}, skip=True)
pipe.add(valid_action, {'add_content': True}) pipe.add(valid_action, {'add_content': True})
pipe.add(valid_action_2) pipe.add(valid_action_2)
assert len(pipe.actions) == 2 assert len(pipe.actions) == 3
assert len(pipe.action_names) == 2 assert len(pipe.action_names) == 3
assert len(pipe.actions_kwargs) == 2 assert len(pipe.actions_kwargs) == 3
assert len(pipe.save_results) == 2 assert len(pipe.action_skip) == 3
assert len(pipe.load_results) == 2 assert len(pipe.save_results) == 3
assert len(pipe.load_results) == 3
assert pipe.save_results[1] == (False, None) assert pipe.save_results[1] == (False, None)
assert pipe.load_results[1] == (False, None) assert pipe.load_results[1] == (False, None)
ret = pipe.run(starting_values=(test_string,)) ret = pipe.run(starting_values=(test_string,))
assert isinstance(ret, tuple) assert isinstance(ret, tuple)
assert pipe._intermediate_result == ret assert pipe._intermediate_result == ret
assert pipe.curr_proc_idx == 3 assert pipe.curr_proc_idx == 4
assert ret is not None assert ret is not None
assert ret[0] == 'test_2_3' assert ret[0] == 'test_2_3'
def test_pipeline_invalid_action(pipeline):
test_string = 'test'
with pytest.raises(WrongActionTypeError):
pipeline.add(test_string, skip=False)

View File

@ -0,0 +1,54 @@
import builtins
import importlib.util
from importlib import reload
import pytest
from lang_main.errors import DependencyMissingError
@pytest.fixture(scope='function')
def no_dep(monkeypatch):
import_orig = builtins.__import__
def mocked_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'py4cytoscape':
raise ImportError()
return import_orig(name, globals, locals, fromlist, level)
monkeypatch.setattr(builtins, '__import__', mocked_import)
@pytest.fixture(scope='function')
def patch_find_spec(monkeypatch):
find_spec_orig = importlib.util.find_spec
def mocked_find_spec(*args, **kwargs):
if args[0] == 'py4cytoscape':
return None
else:
return find_spec_orig(*args, **kwargs)
monkeypatch.setattr(importlib.util, 'find_spec', mocked_find_spec)
def test_p4c_available():
import lang_main.constants
reload(lang_main.constants)
assert lang_main.constants.Dependencies.PY4C.value
@pytest.mark.usefixtures('patch_find_spec')
def test_p4c_missing(monkeypatch):
import lang_main.constants
reload(lang_main.constants)
assert not lang_main.constants.Dependencies.PY4C.value
with pytest.raises(DependencyMissingError):
from lang_main import render
reload(render)

View File

@ -16,7 +16,7 @@ def test_p4c_dependency():
def test_load_config(): def test_load_config():
toml_path = config.PKG_DIR / 'lang_main_config.toml' toml_path = config.PKG_DIR / 'lang_main_config.toml'
loaded_cfg = config.load_toml_config(toml_path) loaded_cfg = config.load_toml_config(toml_path)
assert loaded_cfg['info']['pkg'] == 'lang_main_internal' assert loaded_cfg['paths']['models'] == './lang-models'
def test_get_config_path(): def test_get_config_path():
@ -36,7 +36,7 @@ def test_get_config_path():
assert cyto_internal == cyto_cfg_pth assert cyto_internal == cyto_cfg_pth
def test_load_cfg(monkeypatch, tmp_path): def test_load_cfg_func(monkeypatch, tmp_path):
monkeypatch.setattr(Path, 'cwd', lambda: tmp_path) monkeypatch.setattr(Path, 'cwd', lambda: tmp_path)
pkg_dir = config.PKG_DIR pkg_dir = config.PKG_DIR
filename = config.CONFIG_FILENAME filename = config.CONFIG_FILENAME
@ -44,21 +44,20 @@ def test_load_cfg(monkeypatch, tmp_path):
cfg_pth_internal = (pkg_dir / filename).resolve() cfg_pth_internal = (pkg_dir / filename).resolve()
ref_config = config.load_toml_config(cfg_pth_internal) ref_config = config.load_toml_config(cfg_pth_internal)
assert ref_config['paths']['models'] == './lang-models'
assert ref_config['info']['pkg'] == 'lang_main_internal'
loaded_cfg = config.load_cfg( loaded_cfg = config.load_cfg(
starting_path=pkg_dir, starting_path=pkg_dir,
glob_pattern=filename, glob_pattern=filename,
stop_folder_name=stop_folder, stop_folder_name=stop_folder,
cfg_path_internal=cfg_pth_internal, lookup_cwd=False,
prefer_internal_config=True,
) )
assert loaded_cfg['info']['pkg'] == 'lang_main_internal' assert loaded_cfg['paths']['models'] == '../lang-models'
loaded_cfg = config.load_cfg( loaded_cfg = config.load_cfg(
starting_path=pkg_dir, starting_path=pkg_dir,
glob_pattern=filename, glob_pattern=filename,
stop_folder_name=stop_folder, stop_folder_name=stop_folder,
cfg_path_internal=cfg_pth_internal, lookup_cwd=True,
prefer_internal_config=False,
) )
assert loaded_cfg['info']['pkg'] == 'lang_main' assert loaded_cfg['paths']['models'] == '../lang-models'

View File

@ -4,12 +4,12 @@ from spacy.language import Language
from lang_main import model_loader from lang_main import model_loader
from lang_main.constants import ( from lang_main.constants import (
STFR_MODEL_ARGS_ONNX,
SimilarityFunction, SimilarityFunction,
SpacyModelTypes, SpacyModelTypes,
STFRBackends, STFRBackends,
STFRDeviceTypes, STFRDeviceTypes,
STFRModelTypes, STFRModelTypes,
stfr_model_args_onnx,
) )
from lang_main.errors import LanguageModelNotFoundError from lang_main.errors import LanguageModelNotFoundError
from lang_main.types import LanguageModels from lang_main.types import LanguageModels
@ -69,7 +69,7 @@ def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None:
similarity_func=similarity_func, similarity_func=similarity_func,
backend=STFRBackends.ONNX, backend=STFRBackends.ONNX,
device=STFRDeviceTypes.CPU, device=STFRDeviceTypes.CPU,
model_kwargs=STFR_MODEL_ARGS_ONNX, # type: ignore model_kwargs=stfr_model_args_onnx, # type: ignore
) )
assert isinstance(model, SentenceTransformer) assert isinstance(model, SentenceTransformer)

Binary file not shown.