diff --git a/notebooks/misc.ipynb b/notebooks/misc.ipynb
index 5b47137..b354b50 100644
--- a/notebooks/misc.ipynb
+++ b/notebooks/misc.ipynb
@@ -21,17 +21,26 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 5,
"id": "c0dab307-2c2c-41d2-9867-ec9ba82a8099",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loaded TOML config file successfully.\n"
+ ]
+ }
+ ],
"source": [
- "import networkx as nx"
+ "import networkx as nx\n",
+ "from lang_main.analysis import graphs"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 17,
"id": "629f2051-7ef0-4ce0-a5ad-86b292cc20af",
"metadata": {},
"outputs": [],
@@ -56,7 +65,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 18,
"id": "c4fd9997-1e41-49f1-b879-4b3a6571931d",
"metadata": {},
"outputs": [],
@@ -70,7 +79,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 19,
"id": "bdf1c8d2-1093-420e-91fa-e2edd0cd72f1",
"metadata": {},
"outputs": [
@@ -85,7 +94,7 @@
" (2, 1, {'weight': 6})]"
]
},
- "execution_count": 4,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -96,7 +105,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 68,
"id": "d017b2bc-9cd3-4124-afed-c6eabc07a540",
"metadata": {},
"outputs": [],
@@ -105,9 +114,582 @@
"G.add_edges_from(edges_to_add)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "id": "f8bbf276-3b07-41d6-ad74-778f09cbab96",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graphs.add_weighted_degree(G, 'weight', 'degree_weighted')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "id": "d7b6f917-23f6-44a4-bc8d-125f7658e4d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "OutEdgeView([(1, 2), (1, 3), (1, 4), (2, 4), (2, 1), (3, 4)])"
+ ]
+ },
+ "execution_count": 71,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "G.edges"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "473e9e25-d417-4a0a-bff2-7765de516a89",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "id": "0a48d11d-1f2b-475e-9ddf-bb9a3f67accb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "id": "e340377a-0df4-44ca-b18e-8b354e273eb9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "save_pth = Path.cwd() / 'test.graphml'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "id": "66677ad0-a1e5-4772-a0ba-7fbeeda55297",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nx.write_graphml(G, save_pth)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "id": "f01ebe25-56b9-410a-a2bf-d5a6e211de7a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "G_load = nx.read_graphml(save_pth, node_type=int)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "id": "10bfad35-1f96-41a1-9014-578313502e6c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "OutEdgeView([(1, 2), (1, 3), (1, 4), (2, 4), (2, 1), (3, 4)])"
+ ]
+ },
+ "execution_count": 85,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "G_load.edges"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "66189241-637e-4765-b6f0-6ff090b6ba0a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1af4ba3-ced8-425f-a730-da14fd8aab8e",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "1efd5f4e-fd19-46fd-bb7e-b23bec724cdd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from lang_main.pipelines.predefined import STFR_MODEL"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "50ee13e1-e10e-4efe-8706-6ca321f6cf9a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sents = [\n",
+ " 'Kontrolle der Schmiernippel',\n",
+ " 'Kontrolle der Schmiersysteme',\n",
+ "]\n",
+ "'Kontrolle der Lichtschranken\n",
+ "Überprüfung der Spannrollen\n",
+ "Überprüfung der Druckventile\n",
+ "Kontrolle der Schmiernippel\n",
+ "Kontrolle der Schmiersysteme\n",
+ "Inspektion der Förderbänder\n",
+ "Reinigung der Luftfilter\n",
+ "Inspektion der Schutzabdeckungen\n",
+ "Überprüfung der Ölstände\n",
+ "'Überprüfung der Hydraulik'\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "ca0b4089-d8cc-4566-a9ef-ed35b55d18b0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "embds = STFR_MODEL.encode(sents)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "cff09ea6-04b9-4544-aee5-0a7e0bbda2d2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1.0000, 0.8907],\n",
+ " [0.8907, 1.0000]])"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "STFR_MODEL.similarity(embds, embds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "45dc7050-9b6e-4c62-ba87-a74fb7985933",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "384"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "STFR_MODEL.max_seq_length"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "54bf2e2a-7ada-4e4d-9e2e-1d17631e7d06",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
{
"cell_type": "code",
"execution_count": 8,
+ "id": "c5d970e6-7bfd-4da0-82da-56a12e12a86c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "7dcf9e86-a7d3-436c-a705-cddb83e704bd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = {\n",
+ " 'idx': [0,1,2,3,4],\n",
+ " 'data': ['test1', 'test2', 'test3', 'test4', 'test5']\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "0962d3af-e44d-4078-ac4f-dbd59e6a33eb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df1 = pd.DataFrame.from_dict(data)\n",
+ "df2 = pd.DataFrame.from_dict(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "5743636e-0330-4c7b-879b-0aa8ff6bfa53",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "bool((df1 == df2).all(axis=None))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "d88b4e70-012e-4dfe-ad52-4210386ed8fd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "4afe4713-20d5-4626-a942-e28c4eff8d0a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "p = Path(r'A:\\Arbeitsaufgaben\\lang-main\\tests\\_comparison_results')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "da810b1b-b5cf-4c18-ad26-eff156ccfd54",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "p_load = p / 'merge_similarity_candidates.pkl'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "aa5774ff-5be3-4a7a-92dc-09331f12ee2d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df1 = pd.read_pickle(p_load)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "e4ec576e-ec39-4981-99e7-75fdc7ac0979",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df2 = pd.read_pickle(p_load)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "dc24c3a0-484b-4019-8f2c-4913e36d9b1b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df1_c = df1[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]\n",
+ "df2_c = df2[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "id": "83ade5ae-95f7-4f44-afb2-e1c5a2c5694c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " entry | \n",
+ " len | \n",
+ " num_occur | \n",
+ " num_assoc_obj_ids | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 41 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 29 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 61 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " entry len num_occur num_assoc_obj_ids\n",
+ "41 True True True True\n",
+ "22 True True True True\n",
+ "13 True True True True\n",
+ "6 True True True True\n",
+ "29 True True True True\n",
+ "10 True True True True\n",
+ "17 True True True True\n",
+ "61 True True True True\n",
+ "5 True True True True"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(df1_c == df2_c)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "97d6dd4a-7f3d-4459-bf42-46d0bd087ccd",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "35463772-bf3c-43b4-b536-cf4456b3f0f2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from dateutil import parser"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "fa9c87f8-a42c-447d-bbb3-9c9d6830bd04",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "8d6f97a6-dafa-439e-9d9e-c37515be81bf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pattern_dates = re.compile(r'(\\d{1,2}\\.)?(\\d{1,2}\\.)?([\\d]{2,4})?')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "5a1a15e7-f9bb-463f-9c83-a12ff0f8328e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dates = ['22.05.', '08.2024', '22.05.2024', 'hallo', '22.1250.25']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "442beb19-06ca-46ce-9d64-2e6c632ffb3c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "string = '22.1250.25'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "21e4a7c2-76f4-43bd-aeed-34e52ed53db3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "match = pattern_dates.search(string)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "3f5d75f6-58dd-43a6-abf3-80f581807554",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "('22.', None, '1250')"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "match.groups()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "id": "306bcd91-8b87-47fe-96d4-cbc2a2bbad88",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dates_recog = []\n",
+ "for date in dates:\n",
+ " match = pattern_dates.search(date)\n",
+ " date_found = any(match.groups())\n",
+ " dates_recog.append(date_found)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "id": "4e996e9b-8d75-4060-984e-ee439bfd5d45",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[True, True, True, False, True]"
+ ]
+ },
+ "execution_count": 84,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dates_recog"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
"id": "91d4094b-f886-4056-a697-5223f157f1d3",
"metadata": {},
"outputs": [],
@@ -118,17 +700,1326 @@
},
{
"cell_type": "code",
- "execution_count": 9,
- "id": "518cada9-561a-4b96-b750-3d500d1d28b9",
+ "execution_count": null,
+ "id": "0dabae5f-89b6-4457-a4ef-17cc33c6d561",
"metadata": {},
"outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "8830bbd6-ce01-475b-b492-455400319a9d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loaded TOML config file successfully.\n"
+ ]
+ }
+ ],
"source": [
- "from lang_main.analysis import graphs"
+ "from lang_main import model_loader\n",
+ "from lang_main.analysis import tokens, graphs\n",
+ "\n",
+ "from lang_main.types import SpacyModelTypes"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 2,
+ "id": "ee31987c-9763-4952-8d83-bf9265430e74",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "A:\\Arbeitsaufgaben\\lang-main\\.venv\\Lib\\site-packages\\thinc\\shims\\pytorch.py:261: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
+ " model.load_state_dict(torch.load(filelike, map_location=device))\n"
+ ]
+ }
+ ],
+ "source": [
+ "sentence = (\n",
+ " 'Ich ging am 22.05. mit ID 0912393 schnell über die Wiese zu einem Menschen, um ihm zu helfen. '\n",
+ " 'Ich konnte nicht mit ansehen, wie er Probleme beim Tragen '\n",
+ " 'seiner Tasche hatte.'\n",
+ ")\n",
+ "model = model_loader.load_spacy(\n",
+ " model_name=SpacyModelTypes.DE_CORE_NEWS_SM,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e086ee66-95c3-4fbc-bd04-a16b0fcdb26a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "doc = model(sentence)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "120c886d-6f2d-48e1-a300-8f39d9771204",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from spacy import displacy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "30b5f152-be1f-43c6-8466-98de50a28443",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "displacy.render(doc, style=\"dep\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "944b2da6-2c2a-4a58-b0ad-b2f280b7fecb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sent = list(doc.sents)[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "ad8c1f0a-c46f-4b47-99d3-fa7e254ff570",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'konnte'"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "word = sent[1]\n",
+ "word.text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "189207d4-d0e1-4b8a-be8d-f5328f37c9da",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Ich,\n",
+ " konnte,\n",
+ " nicht,\n",
+ " mit,\n",
+ " ansehen,\n",
+ " ,,\n",
+ " wie,\n",
+ " er,\n",
+ " Probleme,\n",
+ " beim,\n",
+ " Tragen,\n",
+ " seiner,\n",
+ " Tasche,\n",
+ " hatte,\n",
+ " .]"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "list(word.subtree)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "a83bdfab-4ada-482a-b0be-d093f115a6e5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Ich:\t\tPRON\n",
+ "konnte:\t\tAUX\n",
+ "nicht:\t\tPART\n",
+ "mit:\t\tADV\n",
+ "ansehen:\t\tVERB\n",
+ ",:\t\tPUNCT\n",
+ "wie:\t\tSCONJ\n",
+ "er:\t\tPRON\n",
+ "Probleme:\t\tNOUN\n",
+ "beim:\t\tADP\n",
+ "Tragen:\t\tNOUN\n",
+ "seiner:\t\tDET\n",
+ "Tasche:\t\tNOUN\n",
+ "hatte:\t\tVERB\n",
+ ".:\t\tPUNCT\n"
+ ]
+ }
+ ],
+ "source": [
+ "for token in word.subtree:\n",
+ " print(f'{token}:\\t\\t{token.pos_}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "04194be3-7f30-4f02-a3ed-c2ca016652b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from lang_main.analysis import tokens"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "ea169167-f55e-4574-92bc-54aafc75ccc7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'ging'"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "word.text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "da9c5a7b-162d-4b99-b59e-f97bb765d08c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rel_descs = tokens.obtain_relevant_descendants(word)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "2fefb0dc-8285-4f42-9323-23b0bc9d8cc0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(0912393, schnell, Wiese, Menschen)"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tuple(rel_descs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "397e088f-743b-4554-a695-65d0ddaac8ce",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tk_graph = graphs.TokenGraph()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "fd46701c-e428-43f8-80d9-979e96094bf3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokens.add_doc_info_to_graph(tk_graph, doc, weight=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "fc860f52-4bdb-469f-be8b-901bea39224e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "NodeView(('gehen', '0912393', 'schnell', 'Wiese', 'Mensch', 'mit', 'Problem', 'Tragen', 'Tasche', 'ansehen', 'haben'))"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tk_graph.nodes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "56e49d63-7374-428f-a1b0-26e3d136ab9a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "OutEdgeView([('gehen', '0912393'), ('gehen', 'schnell'), ('gehen', 'Wiese'), ('gehen', 'Mensch'), ('mit', 'Problem'), ('mit', 'Tragen'), ('mit', 'Tasche'), ('Problem', 'Tragen'), ('Problem', 'Tasche'), ('Tragen', 'Tasche'), ('ansehen', 'mit'), ('ansehen', 'Problem'), ('ansehen', 'Tragen'), ('ansehen', 'Tasche'), ('haben', 'Problem'), ('haben', 'Tragen'), ('haben', 'Tasche')])"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tk_graph.edges"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "f0056e82-4ddc-4034-afc9-c25e3c2331b9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['gehen',\n",
+ " '0912393',\n",
+ " 'schnell',\n",
+ " 'Wiese',\n",
+ " 'Mensch',\n",
+ " 'mit',\n",
+ " 'Problem',\n",
+ " 'Tragen',\n",
+ " 'Tasche',\n",
+ " 'ansehen',\n",
+ " 'haben']"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "list(tk_graph.nodes)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "id": "ee506f29-a6d0-47b9-a980-0227fa1d2a59",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tkg, undir = graphs.pipe_rescale_graph_edge_weights(tk)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "id": "29a82ea9-6a66-47d3-bdbf-e41284785bc9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1 | \n",
+ " 0.0 | \n",
+ " 0.0952 | \n",
+ " 0.7487 | \n",
+ " 0.9830 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1.0 | \n",
+ " 0.0000 | \n",
+ " 0.0000 | \n",
+ " 0.8959 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0.0 | \n",
+ " 0.0000 | \n",
+ " 0.0000 | \n",
+ " 0.9538 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0.0 | \n",
+ " 0.0000 | \n",
+ " 0.0000 | \n",
+ " 0.0000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 1 2 3 4\n",
+ "1 0.0 0.0952 0.7487 0.9830\n",
+ "2 1.0 0.0000 0.0000 0.8959\n",
+ "3 0.0 0.0000 0.0000 0.9538\n",
+ "4 0.0 0.0000 0.0000 0.0000"
+ ]
+ },
+ "execution_count": 57,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nx.to_pandas_adjacency(tkg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "id": "de96d4db-0c98-4957-a91d-6d12e51fe2ee",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'weight': np.float32(1.0)}"
+ ]
+ },
+ "execution_count": 60,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "undir[2][1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "id": "c802f550-5200-41f5-882e-a8eb780bacf3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1 | \n",
+ " 0.0000 | \n",
+ " 1.0000 | \n",
+ " 0.0952 | \n",
+ " 0.9412 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1.0000 | \n",
+ " 0.0000 | \n",
+ " 0.0000 | \n",
+ " 0.6864 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0.0952 | \n",
+ " 0.0000 | \n",
+ " 0.0000 | \n",
+ " 0.8661 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0.9412 | \n",
+ " 0.6864 | \n",
+ " 0.8661 | \n",
+ " 0.0000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 1 2 3 4\n",
+ "1 0.0000 1.0000 0.0952 0.9412\n",
+ "2 1.0000 0.0000 0.0000 0.6864\n",
+ "3 0.0952 0.0000 0.0000 0.8661\n",
+ "4 0.9412 0.6864 0.8661 0.0000"
+ ]
+ },
+ "execution_count": 58,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nx.to_pandas_adjacency(undir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "84c78e9a-8b34-465c-9bc4-13b38fa0cc32",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "58fe4954-ce69-4442-b6fa-504f1466b1dc",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tk.has_edge(1,2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "id": "baad206f-94ab-495a-8cc2-87a873220401",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1 | \n",
+ " 0.0 | \n",
+ " 7.0 | \n",
+ " 2.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 7.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 3.0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 2.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 4.0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 5.0 | \n",
+ " 3.0 | \n",
+ " 4.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 1 2 3 4\n",
+ "1 0.0 7.0 2.0 5.0\n",
+ "2 7.0 0.0 0.0 3.0\n",
+ "3 2.0 0.0 0.0 4.0\n",
+ "4 5.0 3.0 4.0 0.0"
+ ]
+ },
+ "execution_count": 50,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nx.to_pandas_adjacency(tk.undirected)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "996a303e-02db-496a-bd23-29c92d13d260",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "True\n",
+ "True\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tk.undirected.has_edge(1,2))\n",
+ "print(tk.undirected.has_edge(2,1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "5dbe02a1-3883-44d7-836b-dd2c4d27f5f8",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "Graph.to_undirected() got an unexpected keyword argument 'inplace'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[1;32mIn[52], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m filt \u001b[38;5;241m=\u001b[39m \u001b[43mgraphs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter_graph_by_edge_weight\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtk\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mundirected\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m6\u001b[39;49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[1;32mA:\\Arbeitsaufgaben\\lang-main\\src\\lang_main\\analysis\\graphs.py:230\u001b[0m, in \u001b[0;36mfilter_graph_by_edge_weight\u001b[1;34m(graph, bound_lower, bound_upper)\u001b[0m\n\u001b[0;32m 228\u001b[0m filtered_graph\u001b[38;5;241m.\u001b[39mremove_edge(edge[\u001b[38;5;241m0\u001b[39m], edge[\u001b[38;5;241m1\u001b[39m])\n\u001b[0;32m 229\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bound_upper \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m weight \u001b[38;5;241m>\u001b[39m bound_upper:\n\u001b[1;32m--> 230\u001b[0m filtered_graph\u001b[38;5;241m.\u001b[39mremove_edge(edge[\u001b[38;5;241m0\u001b[39m], edge[\u001b[38;5;241m1\u001b[39m])\n\u001b[0;32m 232\u001b[0m filtered_graph\u001b[38;5;241m.\u001b[39mto_undirected(inplace\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, logging\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 233\u001b[0m filtered_graph\u001b[38;5;241m.\u001b[39mupdate_metadata(logging\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
+ "\u001b[1;31mTypeError\u001b[0m: Graph.to_undirected() got an unexpected keyword argument 'inplace'"
+ ]
+ }
+ ],
+ "source": [
+ "filt = graphs.filter_graph_by_edge_weight(tk.undirected, 2, 6)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "3d9ecb28-23ef-48ac-9cee-86ace6be7af1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 2.0 | \n",
+ " 5.0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 3.0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 2.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 4.0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 5.0 | \n",
+ " 3.0 | \n",
+ " 4.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 1 2 3 4\n",
+ "1 0.0 0.0 2.0 5.0\n",
+ "2 0.0 0.0 0.0 3.0\n",
+ "3 2.0 0.0 0.0 4.0\n",
+ "4 5.0 3.0 4.0 0.0"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nx.to_pandas_adjacency(filt.undirected)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "id": "42345f27-585f-4498-a4cc-50d17c9f9b69",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ValueError",
+ "evalue": "too many values to unpack (expected 2)",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[1;32mIn[54], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mfilt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medges\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mweight\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\n",
+ "File \u001b[1;32mA:\\Arbeitsaufgaben\\lang-main\\.venv\\Lib\\site-packages\\networkx\\classes\\reportviews.py:1095\u001b[0m, in \u001b[0;36mOutEdgeView.__getitem__\u001b[1;34m(self, e)\u001b[0m\n\u001b[0;32m 1090\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, \u001b[38;5;28mslice\u001b[39m):\n\u001b[0;32m 1091\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m nx\u001b[38;5;241m.\u001b[39mNetworkXError(\n\u001b[0;32m 1092\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m does not support slicing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1093\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtry list(G.edges)[\u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39mstart\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39mstop\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m]\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1094\u001b[0m )\n\u001b[1;32m-> 1095\u001b[0m u, v \u001b[38;5;241m=\u001b[39m e\n\u001b[0;32m 1096\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1097\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_adjdict[u][v]\n",
+ "\u001b[1;31mValueError\u001b[0m: too many values to unpack (expected 2)"
+ ]
+ }
+ ],
+ "source": [
+ "filt.edges"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "17e7c931-d94e-43cf-ac97-bb6fccc1ee70",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "7dfa028e-d2e7-4390-bd36-b08b0a591b22",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "False"
+ ]
+ },
+ "execution_count": 48,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "filt.has_edge(1,2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "1b6e6938-1546-490a-9b64-e3d2f60d188d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "False"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "filt.has_edge(2,1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "518cada9-561a-4b96-b750-3d500d1d28b9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(1, 2), (1, 3), (1, 4), (2, 4), (2, 1), (3, 4)]"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "list(tk.edges)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "9830c614-5c16-41fd-8987-be3d421da34a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'degree_weighted': 14}"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tk.nodes[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "42b2bb65-534f-4c9c-b439-d5eec4b285e0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'degree_weighted': 10}"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tk.undirected.nodes[2]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c937e70b-bd89-4c3b-aa09-5f0a63982c13",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
"id": "3235f188-6e99-4855-aa3d-b0e04e3db319",
"metadata": {},
"outputs": [
@@ -144,7 +2035,7 @@
" 'total_memory': 448}"
]
},
- "execution_count": 10,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/notebooks/test.graphml b/notebooks/test.graphml
new file mode 100644
index 0000000..58011f8
--- /dev/null
+++ b/notebooks/test.graphml
@@ -0,0 +1,37 @@
+<?xml version='1.0' encoding='utf-8'?>
+<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
+  <key id="d1" for="edge" attr.name="weight" attr.type="long" />
+  <key id="d0" for="node" attr.name="degree_weighted" attr.type="long" />
+  <graph edgedefault="directed">
+    <node id="1">
+      <data key="d0">14</data>
+    </node>
+    <node id="2">
+      <data key="d0">10</data>
+    </node>
+    <node id="3">
+      <data key="d0">6</data>
+    </node>
+    <node id="4">
+      <data key="d0">12</data>
+    </node>
+    <edge source="1" target="2">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="1" target="3">
+      <data key="d1">2</data>
+    </edge>
+    <edge source="1" target="4">
+      <data key="d1">5</data>
+    </edge>
+    <edge source="2" target="4">
+      <data key="d1">3</data>
+    </edge>
+    <edge source="2" target="1">
+      <data key="d1">6</data>
+    </edge>
+    <edge source="3" target="4">
+      <data key="d1">4</data>
+    </edge>
+  </graph>
+</graphml>
diff --git a/pyproject.toml b/pyproject.toml
index 9df6b04..f40a111 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -118,6 +118,8 @@ exclude_also = [
"def __repr__",
"def __str__",
"@overload",
+ "if logging",
+ "if TYPE_CHECKING",
]
[tool.coverage.html]
diff --git a/src/lang_main/analysis/graphs.py b/src/lang_main/analysis/graphs.py
index ec0af65..3ebce80 100644
--- a/src/lang_main/analysis/graphs.py
+++ b/src/lang_main/analysis/graphs.py
@@ -198,8 +198,10 @@ def filter_graph_by_edge_weight(
graph: TokenGraph,
bound_lower: int | None,
bound_upper: int | None,
+ property: str = 'weight',
) -> TokenGraph:
"""filters all edges which are within the provided bounds
+    inclusive limits: edges with bound_lower <= edge_weight <= bound_upper are retained
Parameters
----------
@@ -216,12 +218,12 @@ def filter_graph_by_edge_weight(
original_graph_edges = copy.deepcopy(graph.edges)
filtered_graph = graph.copy()
- if not any([bound_lower, bound_upper]):
+ if not any((bound_lower, bound_upper)):
logger.warning('No bounds provided, returning original graph.')
return filtered_graph
for edge in original_graph_edges:
- weight = typing.cast(int, filtered_graph[edge[0]][edge[1]]['weight'])
+ weight = typing.cast(int, filtered_graph[edge[0]][edge[1]][property])
if bound_lower is not None and weight < bound_lower:
filtered_graph.remove_edge(edge[0], edge[1])
if bound_upper is not None and weight > bound_upper:
@@ -329,14 +331,12 @@ def static_graph_analysis(
Parameters
----------
tk_graph_directed : TokenGraph
- token graph (directed) and with rescaled edge weights
- tk_graph_undirected : Graph
- token graph (undirected) and with rescaled edge weights
+ token graph (directed)
Returns
-------
- tuple[TokenGraph, Graph]
- token graph (directed) and undirected version with added weighted degree
+ tuple[TokenGraph]
+        token graph (directed) including its undirected version and computed KPIs
"""
graph = graph.copy()
graph.perform_static_analysis()
@@ -559,12 +559,12 @@ class TokenGraph(DiGraph):
return hash(self.__key())
"""
- def copy(self) -> TokenGraph:
+ def copy(self) -> Self:
"""returns a (deep) copy of the graph
Returns
-------
- TokenGraph
+ Self
deep copy of the graph
"""
return copy.deepcopy(self)
@@ -669,7 +669,7 @@ class TokenGraph(DiGraph):
return token_graph, undirected
- def perform_static_analysis(self):
+ def perform_static_analysis(self) -> None:
"""calculate different metrics directly on the data of the underlying graphs
(directed and undirected)
@@ -717,16 +717,11 @@ class TokenGraph(DiGraph):
saving_path = self._save_prepare(path=path, filename=filename)
if directed:
- target_graph = self._directed
- elif not directed and self._undirected is not None:
- target_graph = self._undirected
+ target_graph = self.directed
else:
- raise ValueError('No undirected graph available.')
+ target_graph = self.undirected
save_to_GraphML(graph=target_graph, saving_path=saving_path)
- # saving_path = saving_path.with_suffix('.graphml')
- # nx.write_graphml(G=target_graph, path=saving_path)
- # logger.info('Successfully saved graph as GraphML file under %s.', saving_path)
def to_pickle(
self,
@@ -743,13 +738,14 @@ class TokenGraph(DiGraph):
filename to be given, by default None
"""
saving_path = self._save_prepare(path=path, filename=filename)
- saving_path = saving_path.with_suffix('.pickle')
+ saving_path = saving_path.with_suffix('.pkl')
save_pickle(obj=self, path=saving_path)
@classmethod
def from_file(
cls,
path: Path,
+ node_type_graphml: type = str,
) -> Self:
# !! no validity checks for pickle files
# !! GraphML files not correct because not all properties
@@ -757,7 +753,7 @@ class TokenGraph(DiGraph):
# TODO REWORK
match path.suffix:
case '.graphml':
- graph = typing.cast(Self, nx.read_graphml(path, node_type=int))
+ graph = typing.cast(Self, nx.read_graphml(path, node_type=node_type_graphml))
logger.info('Successfully loaded graph from GraphML file %s.', path)
case '.pkl' | '.pickle':
graph = typing.cast(Self, load_pickle(path))
@@ -767,17 +763,18 @@ class TokenGraph(DiGraph):
return graph
- @classmethod
- def from_pickle(
- cls,
- path: str | Path,
- ) -> Self:
- if isinstance(path, str):
- path = Path(path)
+ # TODO check removal
+ # @classmethod
+ # def from_pickle(
+ # cls,
+ # path: str | Path,
+ # ) -> Self:
+ # if isinstance(path, str):
+ # path = Path(path)
- if path.suffix not in ('.pkl', '.pickle'):
- raise ValueError('File format not supported.')
+ # if path.suffix not in ('.pkl', '.pickle'):
+ # raise ValueError('File format not supported.')
- graph = typing.cast(Self, load_pickle(path))
+ # graph = typing.cast(Self, load_pickle(path))
- return graph
+ # return graph
diff --git a/src/lang_main/analysis/preprocessing.py b/src/lang_main/analysis/preprocessing.py
index 69dac81..dcebabd 100644
--- a/src/lang_main/analysis/preprocessing.py
+++ b/src/lang_main/analysis/preprocessing.py
@@ -205,6 +205,30 @@ def numeric_pre_filter_feature(
bound_lower: int | None,
bound_upper: int | None,
) -> tuple[DataFrame]:
+ """filter DataFrame for a given numerical feature regarding their bounds
+ bounds are inclusive: entries (bound_lower <= entry <= bound_upper) are retained
+
+ Parameters
+ ----------
+ data : DataFrame
+ DataFrame to filter
+ feature : str
+ feature name to filter
+ bound_lower : int | None
+ lower bound of values to retain
+ bound_upper : int | None
+ upper bound of values to retain
+
+ Returns
+ -------
+ tuple[DataFrame]
+ filtered DataFrame
+
+ Raises
+ ------
+ ValueError
+ if no bounds are provided, at least one bound must be set
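+
+    Examples
+    --------
+    A minimal sketch (the column name 'len' is only illustrative):
+
+    >>> import pandas as pd
+    >>> df = pd.DataFrame({'len': [10, 16, 30]})
+    >>> (filtered,) = numeric_pre_filter_feature(df, 'len', bound_lower=16, bound_upper=None)
+    >>> len(filtered)
+    2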
+ """
if not any([bound_lower, bound_upper]):
raise ValueError('No bounds for filtering provided')
@@ -228,7 +252,7 @@ def numeric_pre_filter_feature(
# a more robust identification of duplicates negating negative side effects
# of several disturbances like typos, escape characters, etc.
# build mapping of embeddings for given model
-def merge_similarity_dupl(
+def merge_similarity_duplicates(
data: DataFrame,
model: SentenceTransformer,
cos_sim_threshold: float,
diff --git a/src/lang_main/analysis/tokens.py b/src/lang_main/analysis/tokens.py
index 6b35b45..f9009e7 100644
--- a/src/lang_main/analysis/tokens.py
+++ b/src/lang_main/analysis/tokens.py
@@ -11,6 +11,7 @@ from lang_main.analysis.graphs import (
TokenGraph,
update_graph,
)
+from lang_main.analysis.shared import pattern_dates
from lang_main.constants import (
POS_INDIRECT,
POS_OF_INTEREST,
@@ -38,21 +39,40 @@ def is_str_date(
string: str,
fuzzy: bool = False,
) -> bool:
+ """not stable function to test strings for dates, not 100 percent reliable
+
+ Parameters
+ ----------
+ string : str
+ string to check for dates
+ fuzzy : bool, optional
+        whether to use the fuzzy capability of dateutil.parser.parse, by default False
+
+ Returns
+ -------
+ bool
+        indicates whether a date was found
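+
+    Examples
+    --------
+    Behaviour as exercised in the test suite (heuristic, not exhaustive):
+
+    >>> is_str_date('22.05.2024')
+    True
+    >>> is_str_date('9009090909')
+    False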
+ """
try:
# check if string is a number
# if length is greater than 8, it is not a date
int(string)
- if len(string) > 8:
+ if len(string) not in {2, 4}:
return False
except ValueError:
# not a number
pass
try:
- parse(string, fuzzy=fuzzy)
+ parse(string, fuzzy=fuzzy, dayfirst=True, yearfirst=False)
return True
except ValueError:
- return False
+ date_found: bool = False
+ match = pattern_dates.search(string)
+ if match is None:
+ return date_found
+ date_found = any(match.groups())
+ return date_found
def obtain_relevant_descendants(
@@ -106,7 +126,7 @@ def add_doc_info_to_graph(
if not (token.pos_ in POS_OF_INTEREST or token.tag_ in TAG_OF_INTEREST):
continue
# skip token which are dates or times
- if is_str_date(string=token.text):
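+        # 'NUM' is now part of POS_OF_INTEREST, so plain identifiers such as
+        # '0912393' survive the POS filter; only date-like NUM tokens are dropped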
+ if token.pos_ == 'NUM' and is_str_date(string=token.text):
continue
relevant_descendants = obtain_relevant_descendants(token=token)
@@ -252,32 +272,33 @@ def build_token_graph_simple(
return graph, docs_mapping
-def build_token_graph_old(
- data: DataFrame,
- model: SpacyModel,
-) -> tuple[TokenGraph]:
- # empty NetworkX directed graph
- # graph = nx.DiGraph()
- graph = TokenGraph()
+# TODO check removal
+# def build_token_graph_old(
+# data: DataFrame,
+# model: SpacyModel,
+# ) -> tuple[TokenGraph]:
+# # empty NetworkX directed graph
+# # graph = nx.DiGraph()
+# graph = TokenGraph()
- for row in tqdm(data.itertuples(), total=len(data)):
- # obtain properties from tuple
- # attribute names must match with preprocessed data
- entry_text = cast(str, row.entry)
- weight = cast(int, row.num_occur)
+# for row in tqdm(data.itertuples(), total=len(data)):
+# # obtain properties from tuple
+# # attribute names must match with preprocessed data
+# entry_text = cast(str, row.entry)
+# weight = cast(int, row.num_occur)
- # get spacy model output
- doc = model(entry_text)
+# # get spacy model output
+# doc = model(entry_text)
- add_doc_info_to_graph(
- graph=graph,
- doc=doc,
- weight=weight,
- )
+# add_doc_info_to_graph(
+# graph=graph,
+# doc=doc,
+# weight=weight,
+# )
- # metadata
- graph.update_metadata()
- # convert to undirected
- graph.to_undirected()
+# # metadata
+# graph.update_metadata()
+# # convert to undirected
+# graph.to_undirected()
- return (graph,)
+# return (graph,)
diff --git a/src/lang_main/constants.py b/src/lang_main/constants.py
index 4f27ccf..7b7f50c 100644
--- a/src/lang_main/constants.py
+++ b/src/lang_main/constants.py
@@ -43,6 +43,9 @@ LOGGING_TO_FILE: Final[bool] = CONFIG['logging']['file']
LOGGING_TO_STDERR: Final[bool] = CONFIG['logging']['stderr']
LOGGING_DEFAULT_GRAPHS: Final[bool] = False
+# ** pickling
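+# pinned instead of pickle.HIGHEST_PROTOCOL, presumably so that pickled
+# comparison artefacts stay loadable across interpreter versions
+# (protocol 5 requires Python >= 3.8)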
+PICKLE_PROTOCOL_VERSION: Final[int] = 5
+
# ** paths
input_path_conf = Path.cwd() / Path(CONFIG['paths']['inputs'])
INPUT_PATH_FOLDER: Final[Path] = input_path_conf.resolve()
@@ -91,12 +94,7 @@ else:
STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args
# ** language dependency analysis
# ** POS
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'ADJ', 'VERB', 'AUX'])
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'ADJ', 'VERB', 'AUX'])
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN'])
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX'])
-POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV'])
-# POS_INDIRECT: frozenset[str] = frozenset(['AUX', 'VERB'])
+POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV', 'NUM'])
POS_INDIRECT: frozenset[str] = frozenset(['AUX'])
# ** TAG
# TAG_OF_INTEREST: frozenset[str] = frozenset(['ADJD'])
diff --git a/src/lang_main/io.py b/src/lang_main/io.py
index f5b3af4..9b985eb 100644
--- a/src/lang_main/io.py
+++ b/src/lang_main/io.py
@@ -4,6 +4,7 @@ import shutil
from pathlib import Path
from typing import Any
+from lang_main.constants import PICKLE_PROTOCOL_VERSION
from lang_main.loggers import logger_shared_helpers as logger
@@ -39,7 +40,7 @@ def save_pickle(
path: str | Path,
) -> None:
with open(path, 'wb') as file:
- pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
+ pickle.dump(obj, file, protocol=PICKLE_PROTOCOL_VERSION)
logger.info('Saved file successfully under %s', path)
@@ -56,7 +57,7 @@ def encode_to_base64_str(
obj: Any,
encoding: str = 'utf-8',
) -> str:
- serialised = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+ serialised = pickle.dumps(obj, protocol=PICKLE_PROTOCOL_VERSION)
b64_bytes = base64.b64encode(serialised)
return b64_bytes.decode(encoding=encoding)
diff --git a/src/lang_main/pipelines/predefined.py b/src/lang_main/pipelines/predefined.py
index 0e073a0..22bb04d 100644
--- a/src/lang_main/pipelines/predefined.py
+++ b/src/lang_main/pipelines/predefined.py
@@ -5,7 +5,7 @@ from lang_main.analysis import graphs
from lang_main.analysis.preprocessing import (
analyse_feature,
load_raw_data,
- merge_similarity_dupl,
+ merge_similarity_duplicates,
numeric_pre_filter_feature,
remove_duplicates,
remove_NA,
@@ -100,7 +100,7 @@ def build_merge_duplicates_pipe() -> Pipeline:
},
)
pipe_merge.add(
- merge_similarity_dupl,
+ merge_similarity_duplicates,
{
'model': STFR_MODEL,
'cos_sim_threshold': THRESHOLD_SIMILARITY,
diff --git a/tests/Dummy_Dataset_N_1000.csv b/tests/_comparison_results/Dummy_Dataset_N_1000.csv
similarity index 100%
rename from tests/Dummy_Dataset_N_1000.csv
rename to tests/_comparison_results/Dummy_Dataset_N_1000.csv
diff --git a/tests/_comparison_results/analyse_feature.pkl b/tests/_comparison_results/analyse_feature.pkl
new file mode 100644
index 0000000..359919b
Binary files /dev/null and b/tests/_comparison_results/analyse_feature.pkl differ
diff --git a/tests/analyse_dataset.xlsx b/tests/_comparison_results/analyse_feature.xlsx
similarity index 100%
rename from tests/analyse_dataset.xlsx
rename to tests/_comparison_results/analyse_feature.xlsx
diff --git a/tests/_comparison_results/merge_cands.xlsx b/tests/_comparison_results/merge_cands.xlsx
new file mode 100644
index 0000000..e25b347
Binary files /dev/null and b/tests/_comparison_results/merge_cands.xlsx differ
diff --git a/tests/_comparison_results/merge_similarity_candidates.pkl b/tests/_comparison_results/merge_similarity_candidates.pkl
new file mode 100644
index 0000000..02e1c74
Binary files /dev/null and b/tests/_comparison_results/merge_similarity_candidates.pkl differ
diff --git a/tests/_comparison_results/numeric_pre_filter.pkl b/tests/_comparison_results/numeric_pre_filter.pkl
new file mode 100644
index 0000000..c3013ef
Binary files /dev/null and b/tests/_comparison_results/numeric_pre_filter.pkl differ
diff --git a/tests/_comparison_results/tk_graph_built.pkl b/tests/_comparison_results/tk_graph_built.pkl
new file mode 100644
index 0000000..15b3a5d
Binary files /dev/null and b/tests/_comparison_results/tk_graph_built.pkl differ
diff --git a/tests/analysis/test_graphs.py b/tests/analysis/test_graphs.py
index b962f30..2d59df3 100644
--- a/tests/analysis/test_graphs.py
+++ b/tests/analysis/test_graphs.py
@@ -2,6 +2,7 @@ import networkx as nx
import pytest
from lang_main.analysis import graphs
+from lang_main.errors import EdgePropertyNotContainedError, EmptyEdgesError, EmptyGraphError
TK_GRAPH_NAME = 'TEST_TOKEN_GRAPH'
@@ -40,13 +41,18 @@ def build_init_graph(token_graph: bool):
@pytest.fixture(scope='module')
-def graph():
+def graph() -> graphs.DiGraph:
return build_init_graph(token_graph=False)
@pytest.fixture(scope='module')
-def tk_graph():
- return build_init_graph(token_graph=True)
+def tk_graph() -> graphs.TokenGraph:
+ return build_init_graph(token_graph=True) # type: ignore
+
+
+@pytest.fixture(scope='module')
+def tk_graph_undirected(tk_graph) -> graphs.Graph:
+ return tk_graph.undirected
def test_graph_size(graph):
@@ -61,7 +67,45 @@ def test_save_to_GraphML(graph, tmp_path):
assert saved_file.exists()
-def test_metadata_retrieval(graph):
+def test_save_load_pickle_tk_graph(tk_graph, tmp_path):
+ filename = 'test_save_tkg'
+ tk_graph.to_pickle(tmp_path, filename)
+ load_pth = (tmp_path / filename).with_suffix('.pkl')
+ assert load_pth.exists()
+ loaded_graph = graphs.TokenGraph.from_file(load_pth)
+ assert loaded_graph.nodes == tk_graph.nodes
+ assert loaded_graph.edges == tk_graph.edges
+ filename = None
+ tk_graph.to_pickle(tmp_path, filename)
+ load_pth = (tmp_path / tk_graph.name).with_suffix('.pkl')
+ assert load_pth.exists()
+ loaded_graph = graphs.TokenGraph.from_file(load_pth)
+ assert loaded_graph.nodes == tk_graph.nodes
+ assert loaded_graph.edges == tk_graph.edges
+
+
+@pytest.mark.parametrize(
+ 'import_graph,directed', [('tk_graph', True), ('tk_graph_undirected', False)]
+)
+def test_save_load_GraphML_tk_graph(import_graph, tk_graph, directed, tmp_path, request):
+ test_graph = request.getfixturevalue(import_graph)
+ filename = 'test_save_tkg'
+ tk_graph.to_GraphML(tmp_path, filename, directed=directed)
+ load_pth = (tmp_path / filename).with_suffix('.graphml')
+ assert load_pth.exists()
+ loaded_graph = graphs.TokenGraph.from_file(load_pth, node_type_graphml=int)
+ assert loaded_graph.nodes == test_graph.nodes
+ assert loaded_graph.edges == test_graph.edges
+ filename = None
+ tk_graph.to_GraphML(tmp_path, filename, directed=directed)
+ load_pth = (tmp_path / tk_graph.name).with_suffix('.graphml')
+ assert load_pth.exists()
+ loaded_graph = graphs.TokenGraph.from_file(load_pth, node_type_graphml=int)
+ assert loaded_graph.nodes == test_graph.nodes
+ assert loaded_graph.edges == test_graph.edges
+
+
+def test_get_graph_metadata(graph):
metadata = graphs.get_graph_metadata(graph)
assert metadata['num_nodes'] == 4
assert metadata['num_edges'] == 6
@@ -72,7 +116,7 @@ def test_metadata_retrieval(graph):
assert metadata['total_memory'] == 448
-def test_graph_update_batch():
+def test_update_graph_batch():
graph_obj = build_init_graph(token_graph=False)
graphs.update_graph(graph_obj, batch=((4, 5), (5, 6)), weight_connection=8)
metadata = graphs.get_graph_metadata(graph_obj)
@@ -82,7 +126,7 @@ def test_graph_update_batch():
assert metadata['max_edge_weight'] == 8
-def test_graph_update_single_new():
+def test_update_graph_single_new():
graph_obj = build_init_graph(token_graph=False)
graphs.update_graph(graph_obj, parent=4, child=5, weight_connection=7)
metadata = graphs.get_graph_metadata(graph_obj)
@@ -92,7 +136,7 @@ def test_graph_update_single_new():
assert metadata['max_edge_weight'] == 7
-def test_graph_update_single_existing():
+def test_update_graph_single_existing():
graph_obj = build_init_graph(token_graph=False)
graphs.update_graph(graph_obj, parent=1, child=4, weight_connection=5)
metadata = graphs.get_graph_metadata(graph_obj)
@@ -103,13 +147,13 @@ def test_graph_update_single_existing():
@pytest.mark.parametrize('cast_int', [True, False])
-def test_graph_undirected_conversion(graph, cast_int):
+def test_convert_graph_to_undirected(graph, cast_int):
graph_undir = graphs.convert_graph_to_undirected(graph, cast_int=cast_int)
# edges: (1, 2, w=1) und (2, 1, w=6) --> undirected: (1, 2, w=7)
assert graph_undir[1][2]['weight'] == pytest.approx(7.0)
-def test_graph_cytoscape_conversion(graph):
+def test_convert_graph_to_cytoscape(graph):
cyto_graph, weight_data = graphs.convert_graph_to_cytoscape(graph)
node = cyto_graph[0]
edge = cyto_graph[-1]
@@ -144,7 +188,17 @@ def test_tk_graph_properties(tk_graph):
assert metadata_undirected['total_memory'] == 392
-def test_graph_degree_filter(tk_graph):
+def test_filter_graph_by_edge_weight(tk_graph):
+ filtered_graph = graphs.filter_graph_by_edge_weight(
+ tk_graph,
+ bound_lower=2,
+ bound_upper=5,
+ )
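+    # fixture edges: (1, 2) has weight 1 (below the lower bound) and (2, 1) has
+    # weight 6 (above the upper bound), so both directions must be gone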
+ assert not filtered_graph.has_edge(1, 2)
+ assert not filtered_graph.has_edge(2, 1)
+
+
+def test_filter_graph_by_node_degree(tk_graph):
filtered_graph = graphs.filter_graph_by_node_degree(
tk_graph,
bound_lower=3,
@@ -153,7 +207,7 @@ def test_graph_degree_filter(tk_graph):
assert len(filtered_graph.nodes) == 2
-def test_graph_edge_number_filter(tk_graph):
+def test_filter_graph_by_number_edges(tk_graph):
number_edges_limit = 1
filtered_graph = graphs.filter_graph_by_number_edges(
tk_graph,
@@ -166,3 +220,75 @@ def test_graph_edge_number_filter(tk_graph):
bound_upper=None,
)
assert len(filtered_graph.nodes) == 2, 'one edge should result in only two nodes'
+
+
+def test_add_weighted_degree():
+ graph_obj = build_init_graph(token_graph=False)
+ property_name = 'degree_weighted'
+ graphs.add_weighted_degree(graph_obj, 'weight', property_name)
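+    # expected values follow from the fixture edges, e.g. node 1:
+    # (1,2,w=1) + (1,3,w=2) + (1,4,w=5) + (2,1,w=6) -> 14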
+ assert graph_obj.nodes[1][property_name] == 14
+ assert graph_obj.nodes[2][property_name] == 10
+ assert graph_obj.nodes[3][property_name] == 6
+
+
+def test_static_graph_analysis():
+ graph_obj = build_init_graph(token_graph=True)
+ (graph_obj,) = graphs.static_graph_analysis(graph_obj) # type: ignore
+ property_name = 'degree_weighted'
+ assert graph_obj.nodes[1][property_name] == 14
+ assert graph_obj.nodes[2][property_name] == 10
+ assert graph_obj.nodes[3][property_name] == 6
+ assert graph_obj.undirected.nodes[1][property_name] == 14
+ assert graph_obj.undirected.nodes[2][property_name] == 10
+ assert graph_obj.undirected.nodes[3][property_name] == 6
+
+
+def test_pipe_add_graph_metrics():
+ graph_obj = build_init_graph(token_graph=False)
+ graph_obj_undir = graphs.convert_graph_to_undirected(graph_obj, cast_int=True)
+ graph_collection = graphs.pipe_add_graph_metrics(graph_obj, graph_obj_undir)
+ property_name = 'degree_weighted'
+ assert graph_collection[0].nodes[1][property_name] == 14
+ assert graph_collection[0].nodes[2][property_name] == 10
+ assert graph_collection[0].nodes[3][property_name] == 6
+ assert graph_collection[1].nodes[1][property_name] == 14
+ assert graph_collection[1].nodes[2][property_name] == 10
+ assert graph_collection[1].nodes[3][property_name] == 6
+
+
+def test_pipe_rescale_graph_edge_weights(tk_graph):
+ rescaled_tkg, rescaled_undir = graphs.pipe_rescale_graph_edge_weights(tk_graph)
+ assert rescaled_tkg[2][1]['weight'] == pytest.approx(1.0)
+ assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.0952)
+ assert rescaled_undir[2][1]['weight'] == pytest.approx(1.0)
+ assert rescaled_undir[1][2]['weight'] == pytest.approx(1.0)
+
+
+@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
+def test_rescale_edge_weights(import_graph, request):
+ test_graph = request.getfixturevalue(import_graph)
+ rescaled_graph = graphs.rescale_edge_weights(test_graph)
+ assert rescaled_graph[2][1]['weight'] == pytest.approx(1.0)
+ assert rescaled_graph[1][2]['weight'] == pytest.approx(0.0952)
+
+
+@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
+def test_verify_property(import_graph, request):
+ test_graph = request.getfixturevalue(import_graph)
+ test_property = 'centrality'
+ with pytest.raises(EdgePropertyNotContainedError):
+ graphs.verify_property(test_graph, property=test_property)
+ test_property = 'weight'
+ assert not graphs.verify_property(test_graph, property=test_property)
+
+
+def test_verify_non_empty_graph():
+ graph = nx.Graph()
+ with pytest.raises(EmptyGraphError):
+ graphs.verify_non_empty_graph(graph)
+ graph.add_nodes_from([1, 2, 3, 4])
+ with pytest.raises(EmptyEdgesError):
+ graphs.verify_non_empty_graph(graph, including_edges=True)
+ assert not graphs.verify_non_empty_graph(graph, including_edges=False)
+ graph.add_edges_from([(1, 2), (1, 3), (2, 4)])
+ assert not graphs.verify_non_empty_graph(graph, including_edges=True)
diff --git a/tests/analysis/test_preprocessing.py b/tests/analysis/test_preprocessing.py
index eb6caf9..bc87f15 100644
--- a/tests/analysis/test_preprocessing.py
+++ b/tests/analysis/test_preprocessing.py
@@ -2,8 +2,10 @@
 executed in a pipeline
 """
+from lang_main import model_loader
from lang_main.analysis import preprocessing as ppc
from lang_main.analysis import shared
+from lang_main.types import STFRModelTypes
def test_load_data(raw_data_path, raw_data_date_cols):
@@ -71,3 +74,43 @@ def test_analyse_feature(raw_data_path, raw_data_date_cols):
(data,) = ppc.analyse_feature(data, target_feature=target_features[0])
assert len(data) == 139
+
+
+def test_numeric_pre_filter_feature(data_analyse_feature, data_numeric_pre_filter_feature):
+    # Dataset contains 139 entries. The feature "len" has a minimum value of 15,
+    # which occurs only once. If only values >= 16 are retained, exactly one entry
+    # is filtered out, leaving a total of 138 entries.
+ (data,) = ppc.numeric_pre_filter_feature(
+ data=data_analyse_feature,
+ feature='len',
+ bound_lower=16,
+ bound_upper=None,
+ )
+ assert len(data) == 138
+ eval_merged = data[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]
+ eval_benchmark = data_numeric_pre_filter_feature[
+ ['entry', 'len', 'num_occur', 'num_assoc_obj_ids']
+ ]
+ assert bool((eval_merged == eval_benchmark).all(axis=None))
+
+
+def test_merge_similarity_duplicates(data_analyse_feature, data_merge_similarity_duplicates):
+ cos_sim_threshold = 0.8
+ # reduce dataset to 10 entries
+ data = data_analyse_feature.iloc[:10]
+ model = model_loader.load_sentence_transformer(
+ model_name=STFRModelTypes.ALL_MPNET_BASE_V2,
+ )
+ (merged_data,) = ppc.merge_similarity_duplicates(
+ data=data,
+ model=model,
+ cos_sim_threshold=cos_sim_threshold,
+ )
+ # constructed use case: with this threshold,
+ # 2 out of 10 entries are merged into one
+ assert len(merged_data) == 9
+ eval_merged = merged_data[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]
+ eval_benchmark = data_merge_similarity_duplicates[
+ ['entry', 'len', 'num_occur', 'num_assoc_obj_ids']
+ ]
+ assert bool((eval_merged == eval_benchmark).all(axis=None))
diff --git a/tests/analysis/test_tokens.py b/tests/analysis/test_tokens.py
new file mode 100644
index 0000000..dc16ef2
--- /dev/null
+++ b/tests/analysis/test_tokens.py
@@ -0,0 +1,77 @@
+import pytest
+
+from lang_main import model_loader
+from lang_main.analysis import graphs, tokens
+from lang_main.types import SpacyModelTypes
+
+SENTENCE = (
+ 'Ich ging am 22.05. mit ID 0912393 schnell über die Wiese zu einem Menschen, '
+ 'um ihm zu helfen. Ich konnte nicht mit ansehen, wie er Probleme beim Tragen '
+ 'seiner Tasche hatte.'
+)
+
+
+@pytest.fixture(scope='module')
+def spacy_model():
+ model = model_loader.load_spacy(
+ model_name=SpacyModelTypes.DE_CORE_NEWS_SM,
+ )
+ return model
+
+
+def test_pre_clean_word():
+ string = 'Öl3bad2024prüfung'
+ assert tokens.pre_clean_word(string) == 'Ölbadprüfung'
+
+
+def test_is_str_date():
+ string = '22.05.'
+ assert tokens.is_str_date(string, fuzzy=True)
+ string = '22.05.2024'
+ assert tokens.is_str_date(string)
+ string = '22-05-2024'
+ assert tokens.is_str_date(string)
+ string = '9009090909'
+ assert not tokens.is_str_date(string)
+ string = 'hello347'
+ assert not tokens.is_str_date(string)
+
+
+# TODO: depends on fixed Constants
+def test_obtain_relevant_descendants(spacy_model):
+ doc = spacy_model(SENTENCE)
+ sent1 = tuple(doc.sents)[0] # first sentence
+ word1 = sent1[1] # word "ging" (POS:VERB)
+ descendants1 = ('0912393', 'schnell', 'Wiese', 'Menschen')
+ rel_descs = tokens.obtain_relevant_descendants(word1)
+ rel_descs = tuple((token.text for token in rel_descs))
+ assert descendants1 == rel_descs
+
+    sent2 = tuple(doc.sents)[1]  # second sentence
+ word2 = sent2[1] # word "konnte" (POS:AUX)
+ descendants2 = ('mit', 'Probleme', 'Tragen', 'Tasche')
+ rel_descs = tokens.obtain_relevant_descendants(word2)
+ rel_descs = tuple((token.text for token in rel_descs))
+ assert descendants2 == rel_descs
+
+
+def test_add_doc_info_to_graph(spacy_model):
+ doc = spacy_model(SENTENCE)
+ tk_graph = graphs.TokenGraph()
+ tokens.add_doc_info_to_graph(tk_graph, doc, weight=2)
+ assert len(tk_graph.nodes) == 11
+ assert len(tk_graph.edges) == 17
+ assert '0912393' in tk_graph.nodes
+
+
+def test_build_token_graph(
+ data_merge_similarity_duplicates,
+ spacy_model,
+ data_tk_graph_built,
+):
+ tk_graph, _ = tokens.build_token_graph(
+ data=data_merge_similarity_duplicates,
+ model=spacy_model,
+ )
+ assert len(tk_graph.nodes) == len(data_tk_graph_built.nodes)
+ assert len(tk_graph.edges) == len(data_tk_graph_built.edges)
diff --git a/tests/conftest.py b/tests/conftest.py
index 244efcf..c2f44e6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,7 @@
from pathlib import Path
+from lang_main.analysis import graphs
+import pandas as pd
import pytest
DATE_COLS: tuple[str, ...] = (
@@ -12,7 +14,7 @@ DATE_COLS: tuple[str, ...] = (
@pytest.fixture(scope='session')
def raw_data_path():
- pth_data = Path('./tests/Dummy_Dataset_N_1000.csv')
+ pth_data = Path('./tests/_comparison_results/Dummy_Dataset_N_1000.csv')
assert pth_data.exists()
return pth_data
@@ -21,3 +23,27 @@ def raw_data_path():
@pytest.fixture(scope='session')
def raw_data_date_cols():
return DATE_COLS
+
+
+@pytest.fixture(scope='session')
+def data_analyse_feature() -> pd.DataFrame:
+ pth_data = Path('./tests/_comparison_results/analyse_feature.pkl')
+ return pd.read_pickle(pth_data)
+
+
+@pytest.fixture(scope='session')
+def data_numeric_pre_filter_feature() -> pd.DataFrame:
+ pth_data = Path('./tests/_comparison_results/numeric_pre_filter.pkl')
+ return pd.read_pickle(pth_data)
+
+
+@pytest.fixture(scope='session')
+def data_merge_similarity_duplicates() -> pd.DataFrame:
+ pth_data = Path('./tests/_comparison_results/merge_similarity_candidates.pkl')
+ return pd.read_pickle(pth_data)
+
+
+@pytest.fixture(scope='session')
+def data_tk_graph_built():
+ pth_data = Path('./tests/_comparison_results/tk_graph_built.pkl')
+ return graphs.TokenGraph.from_file(pth_data)