added new graph metrics

commit 80a35c4658 (parent 123869e203)
@@ -1,27 +1,24 @@
 # lang_main: Config file
-[info]
-pkg = 'lang_main'
 
 [paths]
-inputs = './inputs/'
+inputs = './data/'
 # results = './results/dummy_N_1000/'
 # dataset = '../data/Dummy_Dataset_N_1000.csv'
-results = './results/test_20240807/'
-dataset = '../data/02_202307/Export4.csv'
+results = './data/'
+models = '../lang-models'
 
 [logging]
 enabled = true
 stderr = true
 file = true
 
-# only debugging features, production-ready pipelines should always
-# be fully executed
+# control which pipelines are executed
 [control]
-preprocessing_skip = true
+preprocessing_skip = false
 token_analysis_skip = false
 graph_postprocessing_skip = false
 graph_rescaling_skip = false
-graph_static_rendering_skip = false
+graph_static_rendering_skip = true
 time_analysis_skip = true
 
 [preprocess]
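For a quick sanity check of the flags above outside the library, the file can be read with the standard library alone; a minimal sketch, assuming Python 3.11+ for `tomllib` and the config sitting in the current directory (`lang_main` ships its own loader, `load_toml_config`, so this is illustrative only):

```python
# Minimal sketch: read the [control] skip flags from the config above.
# Assumes Python 3.11+ (tomllib) and the file in the current directory.
import tomllib
from pathlib import Path

with Path('lang_main_config.toml').open('rb') as f:
    cfg = tomllib.load(f)

# each [control] key maps a pipeline stage to a boolean skip flag
for stage, skip in cfg['control'].items():
    print(f'{stage}: {"skipped" if skip else "executed"}')
```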
@@ -155,7 +155,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 72,
+"execution_count": 1,
 "id": "0a48d11d-1f2b-475e-9ddf-bb9a3f67accb",
 "metadata": {},
 "outputs": [],
@@ -165,53 +165,471 @@
 },
 {
 "cell_type": "code",
-"execution_count": 73,
+"execution_count": 2,
 "id": "e340377a-0df4-44ca-b18e-8b354e273eb9",
 "metadata": {},
 "outputs": [],
 "source": [
-"save_pth = Path.cwd() / 'test.graphml'"
+"save_pth = Path.cwd() / 'tk_graph_built.pkl'\n",
+"pth_export = Path.cwd() / 'tk_graph_built'\n",
+"assert save_pth.exists()"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 82,
-"id": "66677ad0-a1e5-4772-a0ba-7fbeeda55297",
+"execution_count": 3,
+"id": "8aba87b2-e924-4748-98c6-05c4676f3c08",
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Looking iteratively for config file. Start: A:\\Arbeitsaufgaben\\lang-main\\src\\lang_main, stop folder: src\n",
+"Loaded TOML config file successfully.\n",
+"Loaded config from: >>A:\\Arbeitsaufgaben\\lang-main\\lang_main_config.toml<<\n",
+"Library path is: A:\\Arbeitsaufgaben\\lang-main\n",
+"Root path is: A:\\Arbeitsaufgaben\n"
+]
+}
+],
 "source": [
-"nx.write_graphml(G, save_pth)"
+"from lang_main import io"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 84,
-"id": "f01ebe25-56b9-410a-a2bf-d5a6e211de7a",
+"execution_count": 5,
+"id": "728fdce3-cfe0-4c4b-bcbf-c3d61547e94f",
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"2024-12-19 12:24:12 +0000 | lang_main:io:INFO | Loaded file successfully.\n"
+]
+}
+],
 "source": [
-"G_load = nx.read_graphml(save_pth, node_type=int)"
+"G = io.load_pickle(save_pth)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 85,
-"id": "10bfad35-1f96-41a1-9014-578313502e6c",
+"execution_count": 6,
+"id": "fcb74134-4192-4d68-a535-f5a502d02b67",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"OutEdgeView([(1, 2), (1, 3), (1, 4), (2, 4), (2, 1), (3, 4)])"
+"<networkx.classes.graph.Graph at 0x222d01224d0>"
 ]
 },
-"execution_count": 85,
+"execution_count": 6,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
-"G_load.edges"
+"G.undirected"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "024560e5-373a-46f4-b0f7-7794535daa78",
+"metadata": {},
+"outputs": [],
+"source": []
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "65745c09-e834-47c9-b761-aabd5fa03e57",
+"metadata": {},
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"2024-12-19 11:41:58 +0000 | lang_main:graphs:INFO | Successfully saved graph as GraphML file under A:\\Arbeitsaufgaben\\lang-main\\notebooks\\tk_graph_built.graphml.\n"
+]
+}
+],
+"source": [
+"G.to_GraphML(Path.cwd(), 'tk_graph_built')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 35,
+"id": "8da93a34-7bb1-4b9c-851b-0793c2c483bf",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"TokenGraph(name: TokenGraph, number of nodes: 13, number of edges: 9)"
+]
+},
+"execution_count": 35,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"G"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 28,
+"id": "1cc98c6e-4fa5-49ed-bb97-c62d8e103ccc",
+"metadata": {},
+"outputs": [],
+"source": [
+"G_copy = G.copy()\n",
+"G_copy = G_copy.undirected"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 29,
+"id": "a5044a6f-8903-4e33-a56e-abf647019d8a",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"{'degree_weighted': 2}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 4}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 2}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 1}\n",
+"{'degree_weighted': 1}\n"
+]
+}
+],
+"source": [
+"for node in G_copy.nodes:\n",
+"    print(G_copy.nodes[node])"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "720015a4-7338-4fce-ac6c-38e691e0efa4",
+"metadata": {},
+"outputs": [],
+"source": []
+},
+{
+"cell_type": "code",
+"execution_count": 76,
+"id": "b21a34c8-5748-42b1-947e-76c3ba02a240",
+"metadata": {},
+"outputs": [],
+"source": []
+},
+{
+"cell_type": "code",
+"execution_count": 77,
+"id": "1bcd4a31-2ea5-4123-8fc3-27017acd7259",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"{'weight': 1}\n",
+"{'weight': 1}\n",
+"{'weight': 20}\n",
+"{'weight': 15}\n",
+"{'weight': 10}\n",
+"{'weight': 6}\n",
+"{'weight': 1}\n",
+"{'weight': 1}\n",
+"{'weight': 1}\n"
+]
+}
+],
+"source": [
+"for edge in G_copy.edges:\n",
+"    print(G_copy.edges[edge])"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "bfe5bacb-1f51-45d8-b920-ca9cba01282c",
+"metadata": {},
+"outputs": [],
+"source": []
+},
+{
+"cell_type": "code",
+"execution_count": 32,
+"id": "48e14d11-0319-47a9-81be-c696f69d55b3",
+"metadata": {},
+"outputs": [],
+"source": [
+"import networkx as nx"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 33,
+"id": "479deea9-a44f-48eb-95d3-7a326a83f62c",
+"metadata": {},
+"outputs": [],
+"source": [
+"mapping = nx.betweenness_centrality(G_copy, normalized=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 34,
+"id": "2299dc98-0a82-4009-8b38-790ffe3a3fd8",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{'Kontrolle': 0.015151515151515152,\n",
+" 'Lichtschranke': 0.0,\n",
+" 'Überprüfung': 0.09090909090909091,\n",
+" 'Spannrolle': 0.0,\n",
+" 'Druckventil': 0.0,\n",
+" 'Schmiernippel': 0.0,\n",
+" 'Inspektion': 0.015151515151515152,\n",
+" 'Förderbänder': 0.0,\n",
+" 'Reinigung': 0.0,\n",
+" 'Luftfilter': 0.0,\n",
+" 'Schutzabdeckung': 0.0,\n",
+" 'Ölstand': 0.0,\n",
+" 'Hydraulik': 0.0}"
+]
+},
+"execution_count": 34,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"mapping"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 35,
+"id": "4c00dfe4-29d3-45d8-9eec-6d8e1fbda910",
+"metadata": {},
+"outputs": [],
+"source": [
+"nx.set_node_attributes(G_copy, mapping, name='BC')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 36,
+"id": "0e1119e1-700d-4ae0-bb36-6d4a8fe57698",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"{'degree_weighted': 2, 'BC': 0.015151515151515152}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 4, 'BC': 0.09090909090909091}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 2, 'BC': 0.015151515151515152}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0}\n"
+]
+}
+],
+"source": [
+"nodes_prop_mapping = {}\n",
+"for node in G_copy.nodes:\n",
+"    node_data = G_copy.nodes[node]\n",
+"    prio = node_data['degree_weighted'] * node_data['BC']\n",
+"    nodes_prop_mapping[node] = prio\n",
+"    print(G_copy.nodes[node])"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 37,
+"id": "2970267e-5c9e-472f-857b-382e2406f34d",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{'Kontrolle': 0.030303030303030304,\n",
+" 'Lichtschranke': 0.0,\n",
+" 'Überprüfung': 0.36363636363636365,\n",
+" 'Spannrolle': 0.0,\n",
+" 'Druckventil': 0.0,\n",
+" 'Schmiernippel': 0.0,\n",
+" 'Inspektion': 0.030303030303030304,\n",
+" 'Förderbänder': 0.0,\n",
+" 'Reinigung': 0.0,\n",
+" 'Luftfilter': 0.0,\n",
+" 'Schutzabdeckung': 0.0,\n",
+" 'Ölstand': 0.0,\n",
+" 'Hydraulik': 0.0}"
+]
+},
+"execution_count": 37,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"nodes_prop_mapping"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 38,
+"id": "ef2a1065-9a47-4df4-a7d9-ccfe30dbcdc7",
+"metadata": {},
+"outputs": [],
+"source": [
+"nx.set_node_attributes(G_copy, nodes_prop_mapping, name='prio')"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 39,
+"id": "2b77597a-be7a-498a-8374-f98324574619",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"{'degree_weighted': 2, 'BC': 0.015151515151515152, 'prio': 0.030303030303030304}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 4, 'BC': 0.09090909090909091, 'prio': 0.36363636363636365}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 2, 'BC': 0.015151515151515152, 'prio': 0.030303030303030304}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n",
+"{'degree_weighted': 1, 'BC': 0.0, 'prio': 0.0}\n"
+]
+}
+],
+"source": [
+"for node in G_copy.nodes:\n",
+"    print(G_copy.nodes[node])"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "089ae831-f307-473d-8a70-e72bc580ca61",
+"metadata": {},
+"outputs": [],
+"source": []
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "81d9e8ad-4372-4e3b-a42b-e5b343802936",
+"metadata": {},
+"outputs": [],
+"source": []
+},
+{
+"cell_type": "code",
+"execution_count": 15,
+"id": "ea0a1fc3-571f-4ce2-8561-b0a8f93e5072",
+"metadata": {},
+"outputs": [],
+"source": [
+"def build_init_graph():\n",
+"    edge_weights = [\n",
+"        {'weight': 1},\n",
+"        {'weight': 2},\n",
+"        {'weight': 3},\n",
+"        {'weight': 4},\n",
+"        {'weight': 5},\n",
+"        {'weight': 6},\n",
+"    ]\n",
+"    edges = [\n",
+"        (1, 2),\n",
+"        (1, 3),\n",
+"        (2, 4),\n",
+"        (3, 4),\n",
+"        (1, 4),\n",
+"        (2, 1),\n",
+"    ]\n",
+"    edges_to_add = []\n",
+"    for i, edge in enumerate(edges):\n",
+"        edge = list(edge)\n",
+"        edge.append(edge_weights[i])  # type: ignore\n",
+"        edges_to_add.append(tuple(edge))\n",
+"\n",
+"    G = nx.DiGraph()\n",
+"\n",
+"    G.add_edges_from(edges_to_add)\n",
+"\n",
+"    return G"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 16,
+"id": "6e97af5f-f629-4c71-a1e4-2310ad2f3caa",
+"metadata": {},
+"outputs": [],
+"source": [
+"G_init = build_init_graph()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 17,
+"id": "4ff8fb6b-a10b-448f-b8a5-9989f4bee81e",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{1: 0.16666666666666666, 2: 0.0, 3: 0.0, 4: 0.0}"
+]
+},
+"execution_count": 17,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"nx.betweenness_centrality(G_init, normalized=True)"
 ]
 },
 {
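The notebook experiments above reduce to one idea: a node's importance is its weighted degree multiplied by its normalized betweenness centrality. A standalone sketch with plain networkx (the TokenGraph-specific API such as `.undirected` and `.to_GraphML` is left out; the graph data is made up for illustration):

```python
# Standalone sketch of the metric explored above:
# importance = weighted degree * normalized betweenness centrality.
import networkx as nx

G = nx.Graph()
G.add_weighted_edges_from([(1, 2, 1.0), (1, 3, 2.0), (2, 4, 3.0), (3, 4, 4.0)])

degree_weighted = dict(G.degree(weight='weight'))   # node -> weighted degree
bc = nx.betweenness_centrality(G, normalized=True)  # node -> centrality

prio = {node: degree_weighted[node] * bc[node] for node in G.nodes}
nx.set_node_attributes(G, prio, name='prio')
print(prio)
```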
notebooks/tk_graph_built.graphml (new file, 73 lines)
@@ -0,0 +1,73 @@
+<?xml version='1.0' encoding='utf-8'?>
+<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
+  <key id="d1" for="edge" attr.name="weight" attr.type="long" />
+  <key id="d0" for="node" attr.name="degree_weighted" attr.type="long" />
+  <graph edgedefault="undirected">
+    <node id="Kontrolle">
+      <data key="d0">2</data>
+    </node>
+    <node id="Lichtschranke">
+      <data key="d0">1</data>
+    </node>
+    <node id="Überprüfung">
+      <data key="d0">4</data>
+    </node>
+    <node id="Spannrolle">
+      <data key="d0">1</data>
+    </node>
+    <node id="Druckventil">
+      <data key="d0">1</data>
+    </node>
+    <node id="Schmiernippel">
+      <data key="d0">1</data>
+    </node>
+    <node id="Inspektion">
+      <data key="d0">2</data>
+    </node>
+    <node id="Förderbänder">
+      <data key="d0">1</data>
+    </node>
+    <node id="Reinigung">
+      <data key="d0">1</data>
+    </node>
+    <node id="Luftfilter">
+      <data key="d0">1</data>
+    </node>
+    <node id="Schutzabdeckung">
+      <data key="d0">1</data>
+    </node>
+    <node id="Ölstand">
+      <data key="d0">1</data>
+    </node>
+    <node id="Hydraulik">
+      <data key="d0">1</data>
+    </node>
+    <edge source="Kontrolle" target="Lichtschranke">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Kontrolle" target="Schmiernippel">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Überprüfung" target="Spannrolle">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Überprüfung" target="Druckventil">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Überprüfung" target="Ölstand">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Überprüfung" target="Hydraulik">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Inspektion" target="Förderbänder">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Inspektion" target="Schutzabdeckung">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="Reinigung" target="Luftfilter">
+      <data key="d1">1</data>
+    </edge>
+  </graph>
+</graphml>
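The exported file round-trips through networkx; a small sketch, assuming it is read from the path shown above (node ids are the German token strings, so no `node_type` conversion is needed):

```python
# Read the exported GraphML back; attributes declared via <key> elements
# ('weight' on edges, 'degree_weighted' on nodes) are restored automatically.
import networkx as nx

G = nx.read_graphml('notebooks/tk_graph_built.graphml')
print(G.number_of_nodes(), G.number_of_edges())   # 13 9
print(G.nodes['Überprüfung']['degree_weighted'])  # 4
```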
notebooks/tk_graph_built.pkl (new binary file; not shown)
publish.ps1 (new file, 1 line)
@@ -0,0 +1 @@
+pdm publish -r local --skip-existing
@@ -1,6 +1,6 @@
 [project]
 name = "lang-main"
-version = "0.1.0a1"
+version = "0.1.0a7"
 description = "Several tools to analyse TOM's data with strong focus on language processing"
 authors = [
     {name = "d-opt GmbH, resp. Florian Förster", email = "f.foerster@d-opt.com"},
@@ -57,6 +57,19 @@ distribution = true
 [tool.pdm.build]
 package-dir = "src"
 
+[tool.pdm.resolution]
+respect-source-order = true
+
+[[tool.pdm.source]]
+name = "private"
+url = "http://localhost:8001/simple"
+verify_ssl = false
+
+[[tool.pdm.source]]
+name = "pypi"
+url = "https://pypi.org/simple"
+exclude_packages = ["lang-main*", "tom-plugin*"]
+
 [tool.pdm.dev-dependencies]
 notebooks = [
     "jupyterlab>=4.2.0",
run_test_wo_models+cyto.ps1 (new file, 2 lines)
@@ -0,0 +1,2 @@
+Remove-Item "./logs" -Force -Recurse
+pdm run coverage run -m pytest -m "not mload and not cyto"
@@ -1,3 +1,4 @@
+Remove-Item "./logs" -Force -Recurse
 pdm run pytest --cov -n 4
 # run docker desktop
 . "C:\Program Files\Docker\Docker\Docker Desktop.exe"
@@ -16,12 +16,15 @@ from pandas import DataFrame
 from lang_main.constants import (
     EDGE_WEIGHT_DECIMALS,
     LOGGING_DEFAULT_GRAPHS,
+    PROPERTY_NAME_BETWEENNESS_CENTRALITY,
     PROPERTY_NAME_DEGREE_WEIGHTED,
+    PROPERTY_NAME_IMPORTANCE,
 )
 from lang_main.errors import (
     EdgePropertyNotContainedError,
     EmptyEdgesError,
     EmptyGraphError,
+    NodePropertyNotContainedError,
 )
 from lang_main.io import load_pickle, save_pickle
 from lang_main.loggers import logger_graphs as logger
@@ -310,15 +313,98 @@ def add_weighted_degree(
         property of the edges which contains the weight information, by default 'weight'
     property_name : str, optional
         target name for the property containing the weighted degree in nodes,
-        by default 'degree_weighted'
+        by default PROPERTY_NAME_DEGREE_WEIGHTED
     """
-    node_degree_mapping = cast(
+    node_property_mapping = cast(
         dict[str, float],
         dict(graph.degree(weight=edge_weight_property)),  # type: ignore
     )
     nx.set_node_attributes(
         graph,
-        node_degree_mapping,
+        node_property_mapping,
+        name=property_name,
+    )
+
+
+def add_betweenness_centrality(
+    graph: DiGraph | Graph,
+    edge_weight_property: str | None = None,
+    property_name: str = PROPERTY_NAME_BETWEENNESS_CENTRALITY,
+) -> None:
+    """Adds the betweenness centrality as property to each node of the given graph.
+    Operation is performed inplace.
+
+    Parameters
+    ----------
+    graph : DiGraph | Graph
+        Graph with betweenness centrality as node property added inplace
+    edge_weight_property : str | None, optional
+        property of the edges which contains the weight information,
+        not necessarily needed, by default None
+    property_name : str, optional
+        target name for the property containing the betweenness centrality in nodes,
+        by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
+    """
+
+    node_property_mapping = cast(
+        dict[str, float],
+        nx.betweenness_centrality(graph, normalized=True, weight=edge_weight_property),  # type: ignore
+    )
+    nx.set_node_attributes(
+        graph,
+        node_property_mapping,
+        name=property_name,
+    )
+
+
+def add_importance_metric(
+    graph: DiGraph | Graph,
+    property_name: str = PROPERTY_NAME_IMPORTANCE,
+    property_name_weighted_degree: str = PROPERTY_NAME_DEGREE_WEIGHTED,
+    property_name_betweenness: str = PROPERTY_NAME_BETWEENNESS_CENTRALITY,
+) -> None:
+    """Adds a custom importance metric as property to each node of the given graph.
+    Can be used to decide which nodes are of high importance and also to build node size
+    mappings.
+    Operation is performed inplace.
+
+    Parameters
+    ----------
+    graph : DiGraph | Graph
+        Graph with the importance metric as node property added inplace
+    property_name : str, optional
+        target name for the property containing the importance metric in nodes,
+        by default PROPERTY_NAME_IMPORTANCE
+    property_name_weighted_degree : str, optional
+        name of the node property containing the weighted degree,
+        by default PROPERTY_NAME_DEGREE_WEIGHTED
+    property_name_betweenness : str, optional
+        name of the node property containing the betweenness centrality,
+        by default PROPERTY_NAME_BETWEENNESS_CENTRALITY
+    """
+    # build mapping for importance metric
+    node_property_mapping: dict[str, float] = {}
+    for node in cast(Iterable[str], graph.nodes):
+        node_data = cast(dict[str, float], graph.nodes[node])
+
+        if property_name_weighted_degree not in node_data:
+            raise NodePropertyNotContainedError(
+                (
+                    f'Node data does not contain weighted degree '
+                    f'with name {property_name_weighted_degree}.'
+                )
+            )
+        elif property_name_betweenness not in node_data:
+            raise NodePropertyNotContainedError(
+                (
+                    f'Node data does not contain betweenness centrality '
+                    f'with name {property_name_betweenness}.'
+                )
+            )
+
+        prio = node_data[property_name_weighted_degree] * node_data[property_name_betweenness]
+        node_property_mapping[node] = prio
+
+    nx.set_node_attributes(
+        graph,
+        node_property_mapping,
         name=property_name,
     )
 
@@ -351,6 +437,8 @@ def pipe_add_graph_metrics(
     for graph in graphs:
         graph_copy = copy.deepcopy(graph)
         add_weighted_degree(graph_copy)
+        add_betweenness_centrality(graph_copy)
+        add_importance_metric(graph_copy)
         collection.append(graph_copy)
 
     return tuple(collection)
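The three helpers are meant to run in this order, since `add_importance_metric` reads the two node properties the other functions write and raises `NodePropertyNotContainedError` otherwise; a usage sketch on a toy graph (module path as used by the tests; the graph data is made up):

```python
# Usage sketch for the new metric helpers; all three mutate the graph inplace.
import networkx as nx
from lang_main.analysis import graphs

G = nx.DiGraph()
G.add_weighted_edges_from([(1, 2, 1), (1, 3, 2), (2, 4, 3), (3, 4, 4)])

graphs.add_weighted_degree(G)          # writes 'degree_weighted'
graphs.add_betweenness_centrality(G)   # writes 'betweenness_centrality'
graphs.add_importance_metric(G)        # writes 'importance' (product of both)
print(G.nodes(data=True))
```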
@@ -762,19 +850,3 @@ class TokenGraph(DiGraph):
             raise ValueError('File format not supported.')
 
         return graph
-
-    # TODO check removal
-    # @classmethod
-    # def from_pickle(
-    #     cls,
-    #     path: str | Path,
-    # ) -> Self:
-    #     if isinstance(path, str):
-    #         path = Path(path)
-
-    #     if path.suffix not in ('.pkl', '.pickle'):
-    #         raise ValueError('File format not supported.')
-
-    #     graph = typing.cast(Self, load_pickle(path))
-
-    #     return graph
@@ -19,7 +19,7 @@ except ImportError:
 
 # ** external packages config
 # ** Huggingface Hub caching
-os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = 'set'
+os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
 
 # ** py4cytoscape config
 if _has_py4cyto:
@@ -36,7 +36,7 @@ BASE_FOLDERNAME: Final[str] = os.environ.get('LANG_MAIN_BASE_FOLDERNAME', 'lang-
 CONFIG_FILENAME: Final[str] = 'lang_main_config.toml'
 CYTO_STYLESHEET_FILENAME: Final[str] = r'cytoscape_config/lang_main.xml'
 PKG_DIR: Final[Path] = Path(__file__).parent
-STOP_FOLDER: Final[str] = 'python'
+STOP_FOLDER: Final[str] = os.environ.get('LANG_MAIN_STOP_SEARCH_FOLDERNAME', 'src')
 
 
 def load_toml_config(
@@ -65,6 +65,7 @@ def load_cfg(
     starting_path: Path,
     glob_pattern: str,
     stop_folder_name: str | None,
+    lookup_cwd: bool = False,
 ) -> dict[str, Any]:
     """Look for configuration file. Internal configs are not used any more because
     the library behaviour is only guaranteed by external configurations.
@@ -91,7 +92,8 @@ def load_cfg(
     LangMainConfigNotFoundError
         if no config file was found
     """
-    cfg_path: Path | None
+    cfg_path: Path | None = None
+    if lookup_cwd:
         print('Looking for cfg file in CWD.', flush=True)
         cfg_path = search_cwd(glob_pattern)
 
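A call sketch for the extended signature (argument values mirror `tests/test_config.py` further down; the `config` module import path is an assumption):

```python
# Sketch: resolve the external TOML config, optionally searching the CWD first.
from lang_main import config  # assumed import path

cfg = config.load_cfg(
    starting_path=config.PKG_DIR,
    glob_pattern=config.CONFIG_FILENAME,
    stop_folder_name=config.STOP_FOLDER,
    lookup_cwd=False,  # True: look for the config file in the CWD first
)
print(cfg['paths']['models'])
```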
@@ -54,17 +54,13 @@ PICKLE_PROTOCOL_VERSION: Final[int] = 5
 # config placed in library path of application (usually "bin")
 input_path_cfg = LIB_PATH / Path(CONFIG['paths']['inputs'])
 INPUT_PATH_FOLDER: Final[Path] = input_path_cfg.resolve()
-# TODO reactivate later
-if not INPUT_PATH_FOLDER.exists():
+if not INPUT_PATH_FOLDER.exists():  # pragma: no cover
     raise FileNotFoundError(f'Input path >>{INPUT_PATH_FOLDER}<< does not exist.')
 save_path_cfg = LIB_PATH / Path(CONFIG['paths']['results'])
 SAVE_PATH_FOLDER: Final[Path] = save_path_cfg.resolve()
-if not SAVE_PATH_FOLDER.exists():
+if not SAVE_PATH_FOLDER.exists():  # pragma: no cover
     raise FileNotFoundError(f'Output path >>{SAVE_PATH_FOLDER}<< does not exist.')
-path_dataset_cfg = LIB_PATH / Path(CONFIG['paths']['dataset'])
-PATH_TO_DATASET: Final[Path] = path_dataset_cfg.resolve()
-# if not PATH_TO_DATASET.exists():
-#     raise FileNotFoundError(f'Dataset path >>{PATH_TO_DATASET}<< does not exist.')
 # ** control
 SKIP_PREPROCESSING: Final[bool] = CONFIG['control']['preprocessing_skip']
 SKIP_TOKEN_ANALYSIS: Final[bool] = CONFIG['control']['token_analysis_skip']
@@ -82,22 +78,34 @@ MODEL_BASE_FOLDER: Final[Path] = model_folder_cfg.resolve()
 if not MODEL_BASE_FOLDER.exists():
     raise FileNotFoundError('Language model folder not found.')
 os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(MODEL_BASE_FOLDER)
-SPACY_MODEL_NAME: Final[SpacyModelTypes] = SpacyModelTypes.DE_CORE_NEWS_SM
-STFR_MODEL_NAME: Final[STFRModelTypes] = STFRModelTypes.ALL_MPNET_BASE_V2
+# LANG_MAIN_BASE_FOLDERNAME : base folder of the library, not the root (the folder in which the Python installation is found)
+# LANG_MAIN_SPACY_MODEL : spaCy model to use; falls back to the constant defined in the library; mostly internal use
+# LANG_MAIN_STFR_MODEL : Sentence Transformer model to use; falls back to the constant defined in the library; mostly internal use
+# LANG_MAIN_STFR_BACKEND : STFR backend, choice between "torch" and "onnx"
+
+SPACY_MODEL_NAME: Final[str | SpacyModelTypes] = os.environ.get(
+    'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_CORE_NEWS_SM
+)
+STFR_MODEL_NAME: Final[str | STFRModelTypes] = os.environ.get(
+    'LANG_MAIN_STFR_MODEL', STFRModelTypes.ALL_MPNET_BASE_V2
+)
 STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU
 STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE
-STFR_BACKEND: Final[STFRBackends] = STFRBackends.TORCH
-STFR_MODEL_ARGS_DEFAULT: STFRModelArgs = {}
-STFR_MODEL_ARGS_ONNX: STFRModelArgs = {
+STFR_BACKEND: Final[str | STFRBackends] = os.environ.get(
+    'LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH
+)
+stfr_model_args_default: STFRModelArgs = {}
+stfr_model_args_onnx: STFRModelArgs = {
     'file_name': STFRQuantFilenames.ONNX_Q_UINT8,
     'provider': ONNXExecutionProvider.CPU,
     'export': False,
 }
 stfr_model_args: STFRModelArgs
 if STFR_BACKEND == STFRBackends.ONNX:
-    stfr_model_args = STFR_MODEL_ARGS_ONNX
+    stfr_model_args = stfr_model_args_onnx
 else:
-    stfr_model_args = STFR_MODEL_ARGS_DEFAULT
+    stfr_model_args = stfr_model_args_default
 
 STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args
 # ** language dependency analysis
@@ -122,6 +130,8 @@ THRESHOLD_SIMILARITY: Final[float] = CONFIG['preprocess']['threshold_similarity'
 EDGE_WEIGHT_DECIMALS: Final[int] = 4
 THRESHOLD_EDGE_NUMBER: Final[int] = CONFIG['graph_postprocessing']['threshold_edge_number']
 PROPERTY_NAME_DEGREE_WEIGHTED: Final[str] = 'degree_weighted'
+PROPERTY_NAME_BETWEENNESS_CENTRALITY: Final[str] = 'betweenness_centrality'
+PROPERTY_NAME_IMPORTANCE: Final[str] = 'importance'
 
 # ** graph exports (Cytoscape)
 CYTO_MAX_NODE_COUNT: Final[int] = 500
src/lang_main/env_vars.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
+# list of all environment variables used by the library
+LANG_MAIN_STOP_SEARCH_FOLDERNAME : folder name in the package directory tree at which the lookup stops; used to find the directory root
+LANG_MAIN_BASE_FOLDERNAME : base folder of the library, not the root (the folder in which the Python installation is found)
+LANG_MAIN_SPACY_MODEL : spaCy model to use; falls back to the constant defined in the library; mostly internal use
+LANG_MAIN_STFR_MODEL : Sentence Transformer model to use; falls back to the constant defined in the library; mostly internal use
+LANG_MAIN_STFR_BACKEND : STFR backend, choice between "torch" and "onnx"
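Since `constants.py` reads these variables at import time via `os.environ.get(...)`, they must be set before the first `lang_main` import; a sketch (it assumes a discoverable config file, otherwise the import itself fails):

```python
# Sketch: override model/backend selection through environment variables.
# Must happen before lang_main.constants is imported (values are read once).
import os

os.environ['LANG_MAIN_STFR_BACKEND'] = 'onnx'  # choice between "torch" and "onnx"
os.environ['LANG_MAIN_SPACY_MODEL'] = 'de_core_news_sm'

import lang_main.constants  # noqa: E402  (import after setting the variables)

print(lang_main.constants.STFR_BACKEND)  # 'onnx'
```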
@@ -8,6 +8,10 @@ class LanguageModelNotFoundError(Exception):
 
 
 # ** token graph exceptions
+class NodePropertyNotContainedError(Exception):
+    """Error raised if a needed node property is not contained in the graph's nodes"""
+
+
 class EdgePropertyNotContainedError(Exception):
     """Error raised if a needed edge property is not contained in graph edges"""
 
@@ -1,22 +1,18 @@
 # lang_main: Config file
-[info]
-pkg = 'lang_main_internal'
 
 [paths]
-inputs = './data/in/'
+inputs = '../data/in/'
 # results = './results/dummy_N_1000/'
 # dataset = '../data/Dummy_Dataset_N_1000.csv'
-results = './data/out/'
-dataset = '../data/02_202307/Export4.csv'
-models = '../../lang-models'
+results = '../data/out/'
+models = './lang-models'
 
 [logging]
 enabled = true
 stderr = true
 file = true
 
-# only debugging features, production-ready pipelines should always
-# be fully executed
+# control which pipelines are executed
 [control]
 preprocessing_skip = false
 token_analysis_skip = false
@@ -33,7 +33,7 @@ if ENABLE_LOGGING and LOGGING_TO_STDERR:
     logger_all_handler_stderr = logging.StreamHandler()
     logger_all_handler_stderr.setLevel(LOGGING_LEVEL_STDERR)
     logger_all_handler_stderr.setFormatter(logger_all_formater)
-else:
+else:  # pragma: no cover
     logger_all_handler_stderr = null_handler
 
 if ENABLE_LOGGING and LOGGING_TO_FILE:
@@ -45,7 +45,7 @@ if ENABLE_LOGGING and LOGGING_TO_FILE:
     )
     logger_all_handler_file.setLevel(LOGGING_LEVEL_FILE)
     logger_all_handler_file.setFormatter(logger_all_formater)
-else:
+else:  # pragma: no cover
     logger_all_handler_file = null_handler
 
 
@@ -33,6 +33,7 @@ class BasePipeline(ABC):
         # container for actions to perform during pass
         self.actions: list[Callable] = []
         self.action_names: list[str] = []
+        self.action_skip: list[bool] = []
         # progress tracking, start at 1
         self.curr_proc_idx: int = 1
 
@@ -104,8 +105,6 @@ class PipelineContainer(BasePipeline):
     ) -> None:
         super().__init__(name=name, working_dir=working_dir)
 
-        self.action_skip: list[bool] = []
-
     @override
     def add(
         self,
@@ -170,6 +169,7 @@ class Pipeline(BasePipeline):
         self,
         action: Callable,
         action_kwargs: dict[str, Any] | None = None,
+        skip: bool = False,
         save_result: bool = False,
         load_result: bool = False,
         filename: str | None = None,
@@ -183,6 +183,7 @@ class Pipeline(BasePipeline):
             self.actions.append(action)
             self.action_names.append(action.__name__)
             self.actions_kwargs.append(action_kwargs.copy())
+            self.action_skip.append(skip)
             self.save_results.append((save_result, filename))
             self.load_results.append((load_result, filename))
         else:
@@ -235,7 +236,13 @@ class Pipeline(BasePipeline):
         self,
         starting_values: tuple[Any, ...] | None = None,
     ) -> tuple[Any, ...]:
+        first_performed: bool = False
+
         for idx, (action, action_kwargs) in enumerate(zip(self.actions, self.actions_kwargs)):
+            if self.action_skip[idx]:
+                self.curr_proc_idx += 1
+                continue
+
             # loading
             if self.load_results[idx][0]:
                 filename = self.load_results[idx][1]
@@ -248,8 +255,9 @@ class Pipeline(BasePipeline):
                 self.curr_proc_idx += 1
                 continue
             # calculation
-            if idx == 0:
+            if not first_performed:
                 args = starting_values
+                first_performed = True
             else:
                 args = ret
 
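A behavioural sketch of the new per-action `skip` flag, mirroring `test_pipeline_multiple_actions` below: a skipped action is registered and counted but never executed, and still advances `curr_proc_idx`. The `Pipeline` import path and constructor arguments here are assumptions:

```python
# Sketch only: import path and constructor arguments are assumptions;
# the skip semantics follow the test diffs below.
from lang_main.pipelines import Pipeline  # hypothetical import path

def add_suffix(string: str) -> str:
    return string + '_2'

pipe = Pipeline(name='demo', working_dir='.')  # assumed signature
pipe.add(add_suffix, skip=True)  # registered, counted, but never run
pipe.add(add_suffix)
ret = pipe.run(starting_values=('test',))
print(ret[0])  # 'test_2': only the second action was executed
```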
@@ -296,7 +296,7 @@ def apply_style_to_network(
     style_name: str = CYTO_STYLESHEET_NAME,
     pth_to_stylesheet: Path = CYTO_PATH_STYLESHEET,
     network_name: str = CYTO_BASE_NETWORK_NAME,
-    node_size_property: str = 'node_selection',
+    node_size_property: str = CYTO_SELECTION_PROPERTY,
     min_node_size: int = 15,
     max_node_size: int = 40,
     sandbox_name: str = CYTO_SANDBOX_NAME,
@@ -2,7 +2,12 @@ import networkx as nx
 import pytest
 
 from lang_main.analysis import graphs
-from lang_main.errors import EmptyEdgesError, EmptyGraphError, EdgePropertyNotContainedError
+from lang_main.errors import (
+    EdgePropertyNotContainedError,
+    EmptyEdgesError,
+    EmptyGraphError,
+    NodePropertyNotContainedError,
+)
 
 TK_GRAPH_NAME = 'TEST_TOKEN_GRAPH'
 
@@ -231,6 +236,49 @@ def test_add_weighted_degree():
     assert graph_obj.nodes[3][property_name] == 6
 
 
+def test_add_betweenness_centrality():
+    graph_obj = build_init_graph(token_graph=False)
+    property_name = 'betweenness_centrality'
+    graphs.add_betweenness_centrality(graph_obj, property_name=property_name)
+    assert round(graph_obj.nodes[1][property_name], 4) == pytest.approx(0.1667)
+    assert graph_obj.nodes[2][property_name] == 0
+    assert graph_obj.nodes[3][property_name] == 0
+
+
+def test_add_importance_metric():
+    graph_obj = build_init_graph(token_graph=False)
+    property_name_WD = 'degree_weighted'
+    graphs.add_weighted_degree(graph_obj, 'weight', property_name_WD)
+    property_name_BC = 'betweenness_centrality'
+    graphs.add_betweenness_centrality(graph_obj, property_name=property_name_BC)
+    property_name = 'importance'
+    graphs.add_importance_metric(
+        graph_obj,
+        property_name=property_name,
+        property_name_weighted_degree=property_name_WD,
+        property_name_betweenness=property_name_BC,
+    )
+    assert round(graph_obj.nodes[1][property_name], 4) == pytest.approx(2.3333)
+    assert graph_obj.nodes[2][property_name] == 0
+    assert graph_obj.nodes[3][property_name] == 0
+
+    with pytest.raises(NodePropertyNotContainedError):
+        graphs.add_importance_metric(
+            graph_obj,
+            property_name=property_name,
+            property_name_weighted_degree='prop_not_contained',
+            property_name_betweenness=property_name_BC,
+        )
+
+    with pytest.raises(NodePropertyNotContainedError):
+        graphs.add_importance_metric(
+            graph_obj,
+            property_name=property_name,
+            property_name_weighted_degree=property_name_WD,
+            property_name_betweenness='prop_not_contained',
+        )
+
+
 def test_static_graph_analysis():
     graph_obj = build_init_graph(token_graph=True)
     (graph_obj,) = graphs.static_graph_analysis(graph_obj)  # type: ignore
@@ -254,6 +302,20 @@ def test_pipe_add_graph_metrics():
     assert graph_collection[1].nodes[1][property_name] == 14
     assert graph_collection[1].nodes[2][property_name] == 10
     assert graph_collection[1].nodes[3][property_name] == 6
+    property_name = 'betweenness_centrality'
+    assert round(graph_collection[0].nodes[1][property_name], 4) == pytest.approx(0.1667)
+    assert graph_collection[0].nodes[2][property_name] == 0
+    assert graph_collection[0].nodes[3][property_name] == 0
+    assert round(graph_collection[1].nodes[1][property_name], 4) == pytest.approx(0.1667)
+    assert graph_collection[1].nodes[2][property_name] == 0
+    assert graph_collection[1].nodes[3][property_name] == 0
+    property_name = 'importance'
+    assert round(graph_collection[0].nodes[1][property_name], 4) == pytest.approx(2.3333)
+    assert graph_collection[0].nodes[2][property_name] == 0
+    assert graph_collection[0].nodes[3][property_name] == 0
+    assert round(graph_collection[1].nodes[1][property_name], 4) == pytest.approx(2.3333)
+    assert graph_collection[1].nodes[2][property_name] == 0
+    assert graph_collection[1].nodes[3][property_name] == 0
 
 
 def test_pipe_rescale_graph_edge_weights(tk_graph):
@@ -121,6 +121,7 @@ def test_pipeline_valid(pipeline, alter_content):
     assert len(pipe.actions) == 1
     assert len(pipe.action_names) == 1
     assert len(pipe.actions_kwargs) == 1
+    assert len(pipe.action_skip) == 1
     assert len(pipe.save_results) == 1
     assert len(pipe.load_results) == 1
     assert pipe.save_results[0] == (False, None)
@@ -166,7 +167,7 @@ def test_pipeline_valid_action_load(pipeline, working_dir):
     test_string = 'test'
 
     # action preparation
-    def valid_action(string, add_content=False):
+    def valid_action(string, add_content=False):  # pragma: no cover
         if add_content:
             string += '_2'
         return string
@@ -175,6 +176,7 @@ def test_pipeline_valid_action_load(pipeline, working_dir):
     assert len(pipe.actions) == 1
     assert len(pipe.action_names) == 1
     assert len(pipe.actions_kwargs) == 1
+    assert len(pipe.action_skip) == 1
     assert len(pipe.save_results) == 1
     assert len(pipe.load_results) == 1
     assert pipe.save_results[0] == (False, None)
@@ -209,19 +211,28 @@ def test_pipeline_multiple_actions(pipeline):
         string += '_3'
         return string
 
+    pipe.add(valid_action, {'add_content': True}, skip=True)
     pipe.add(valid_action, {'add_content': True})
     pipe.add(valid_action_2)
-    assert len(pipe.actions) == 2
-    assert len(pipe.action_names) == 2
-    assert len(pipe.actions_kwargs) == 2
-    assert len(pipe.save_results) == 2
-    assert len(pipe.load_results) == 2
+    assert len(pipe.actions) == 3
+    assert len(pipe.action_names) == 3
+    assert len(pipe.actions_kwargs) == 3
+    assert len(pipe.action_skip) == 3
+    assert len(pipe.save_results) == 3
+    assert len(pipe.load_results) == 3
     assert pipe.save_results[1] == (False, None)
     assert pipe.load_results[1] == (False, None)
 
     ret = pipe.run(starting_values=(test_string,))
     assert isinstance(ret, tuple)
     assert pipe._intermediate_result == ret
-    assert pipe.curr_proc_idx == 3
+    assert pipe.curr_proc_idx == 4
    assert ret is not None
     assert ret[0] == 'test_2_3'
 
 
+def test_pipeline_invalid_action(pipeline):
+    test_string = 'test'
+
+    with pytest.raises(WrongActionTypeError):
+        pipeline.add(test_string, skip=False)
tests/render/test_import.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+import builtins
+import importlib.util
+from importlib import reload
+
+import pytest
+
+from lang_main.errors import DependencyMissingError
+
+
+@pytest.fixture(scope='function')
+def no_dep(monkeypatch):
+    import_orig = builtins.__import__
+
+    def mocked_import(name, globals, locals, fromlist, level):
+        if name == 'py4cytoscape':
+            raise ImportError()
+        return import_orig(name, globals, locals, fromlist, level)
+
+    monkeypatch.setattr(builtins, '__import__', mocked_import)
+
+
+@pytest.fixture(scope='function')
+def patch_find_spec(monkeypatch):
+    find_spec_orig = importlib.util.find_spec
+
+    def mocked_find_spec(*args, **kwargs):
+        if args[0] == 'py4cytoscape':
+            return None
+        else:
+            return find_spec_orig(*args, **kwargs)
+
+    monkeypatch.setattr(importlib.util, 'find_spec', mocked_find_spec)
+
+
+def test_p4c_available():
+    import lang_main.constants
+
+    reload(lang_main.constants)
+
+    assert lang_main.constants.Dependencies.PY4C.value
+
+
+@pytest.mark.usefixtures('patch_find_spec')
+def test_p4c_missing(monkeypatch):
+    import lang_main.constants
+
+    reload(lang_main.constants)
+
+    assert not lang_main.constants.Dependencies.PY4C.value
+
+    with pytest.raises(DependencyMissingError):
+        from lang_main import render
+
+        reload(render)
@@ -16,7 +16,7 @@ def test_p4c_dependency():
 def test_load_config():
     toml_path = config.PKG_DIR / 'lang_main_config.toml'
     loaded_cfg = config.load_toml_config(toml_path)
-    assert loaded_cfg['info']['pkg'] == 'lang_main_internal'
+    assert loaded_cfg['paths']['models'] == './lang-models'
 
 
 def test_get_config_path():
@@ -36,7 +36,7 @@ def test_get_config_path():
     assert cyto_internal == cyto_cfg_pth
 
 
-def test_load_cfg(monkeypatch, tmp_path):
+def test_load_cfg_func(monkeypatch, tmp_path):
     monkeypatch.setattr(Path, 'cwd', lambda: tmp_path)
     pkg_dir = config.PKG_DIR
     filename = config.CONFIG_FILENAME
@@ -44,21 +44,20 @@ def test_load_cfg(monkeypatch, tmp_path):
 
     cfg_pth_internal = (pkg_dir / filename).resolve()
     ref_config = config.load_toml_config(cfg_pth_internal)
-    assert ref_config['info']['pkg'] == 'lang_main_internal'
+    assert ref_config['paths']['models'] == './lang-models'
 
     loaded_cfg = config.load_cfg(
         starting_path=pkg_dir,
         glob_pattern=filename,
         stop_folder_name=stop_folder,
-        cfg_path_internal=cfg_pth_internal,
-        prefer_internal_config=True,
+        lookup_cwd=False,
     )
-    assert loaded_cfg['info']['pkg'] == 'lang_main_internal'
+    assert loaded_cfg['paths']['models'] == '../lang-models'
 
     loaded_cfg = config.load_cfg(
         starting_path=pkg_dir,
         glob_pattern=filename,
         stop_folder_name=stop_folder,
-        cfg_path_internal=cfg_pth_internal,
-        prefer_internal_config=False,
+        lookup_cwd=True,
    )
-    assert loaded_cfg['info']['pkg'] == 'lang_main'
+    assert loaded_cfg['paths']['models'] == '../lang-models'
@@ -4,12 +4,12 @@ from spacy.language import Language
 
 from lang_main import model_loader
 from lang_main.constants import (
-    STFR_MODEL_ARGS_ONNX,
     SimilarityFunction,
     SpacyModelTypes,
     STFRBackends,
     STFRDeviceTypes,
     STFRModelTypes,
+    stfr_model_args_onnx,
 )
 from lang_main.errors import LanguageModelNotFoundError
 from lang_main.types import LanguageModels
@@ -69,7 +69,7 @@ def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None:
         similarity_func=similarity_func,
         backend=STFRBackends.ONNX,
         device=STFRDeviceTypes.CPU,
-        model_kwargs=STFR_MODEL_ARGS_ONNX,  # type: ignore
+        model_kwargs=stfr_model_args_onnx,  # type: ignore
     )
     assert isinstance(model, SentenceTransformer)
 
Binary file not shown.
tests/work_dir/Pipe-test_Step-3_valid_action_2.pkl (new binary file; not shown)