Javihaus committed on
Commit
530fbcd
·
verified ·
1 Parent(s): af2d3d2

Upload Hamiltonian_final_version.ipynb

Browse files
Notebook/Hamiltonian_final_version.ipynb ADDED
@@ -0,0 +1,1817 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "source": [
20
+ "\"\"\"\n",
21
+ "This script implements the experiments conducted for\n",
22
+ "the paper \"Optimizing AI Reasoning: A Hamiltonian Dynamics Approach to\n",
23
+ "Multi-Hop Question Answering\".\n",
24
+ "\n",
25
+ "Author: Javier Marín\n",
26
+ "Email: [email protected]\n",
27
+ "Version: 1.0.0\n",
28
+ "Date: October 2024\n",
29
+ "\n",
30
+ "License: MIT License\n",
31
+ "\n",
32
+ "Copyright (c) 2024 Javier Marín\n",
33
+ "\n",
34
+ "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
35
+ "of this software and associated documentation files (the \"Software\"), to deal\n",
36
+ "in the Software without restriction, including without limitation the rights\n",
37
+ "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
38
+ "copies of the Software, and to permit persons to whom the Software is\n",
39
+ "furnished to do so, subject to the following conditions:\n",
40
+ "\n",
41
+ "The above copyright notice and this permission notice shall be included in all\n",
42
+ "copies or substantial portions of the Software.\n",
43
+ "\n",
44
+ "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
45
+ "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
46
+ "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
47
+ "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
48
+ "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
49
+ "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n",
50
+ "SOFTWARE.\n",
51
+ "\n",
52
+ "Dependencies:\n",
53
+ "- Python 3.8+\n",
54
+ "- NumPy\n",
55
+ "- Pandas\n",
56
+ "- PyTorch\n",
57
+ "- Transformers\n",
58
+ "- Scikit-learn\n",
59
+ "- SciPy\n",
60
+ "- Statsmodels\n",
61
+ "- Matplotlib\n",
62
+ "- Seaborn\n",
63
+ "\n",
64
+ "For a full list of dependencies and their versions, see requirements.txt\n",
65
+ "\"\"\""
66
+ ],
67
+ "metadata": {
68
+ "id": "T-57ivc-aTrA"
69
+ },
70
+ "execution_count": null,
71
+ "outputs": []
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "source": [
76
+ "## Imports"
77
+ ],
78
+ "metadata": {
79
+ "id": "QUcpzyBmWpLv"
80
+ }
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "metadata": {
86
+ "id": "l2rfFoVtIL6_"
87
+ },
88
+ "outputs": [],
89
+ "source": [
90
+ "# Standard library imports\n",
91
+ "import os\n",
92
+ "import re\n",
93
+ "import time\n",
94
+ "\n",
95
+ "# Third-party imports\n",
96
+ "import numpy as np\n",
97
+ "import pandas as pd\n",
98
+ "import torch\n",
99
+ "import seaborn as sns\n",
100
+ "import matplotlib.pyplot as plt\n",
101
+ "from mpl_toolkits.mplot3d import Axes3D\n",
102
+ "\n",
103
+ "from transformers import AutoTokenizer, AutoModel\n",
104
+ "from statsmodels.multivariate.manova import MANOVA\n",
105
+ "from scipy import stats\n",
106
+ "from scipy.optimize import curve_fit\n",
107
+ "from scipy.integrate import odeint\n",
108
+ "from sklearn import (\n",
109
+ " metrics,\n",
110
+ " model_selection,\n",
111
+ " cluster,\n",
112
+ " decomposition,\n",
113
+ " feature_extraction,\n",
114
+ " linear_model\n",
115
+ ")\n",
116
+ "\n",
117
+ "# Visualization settings\n",
118
+ "sns.set_theme(style=\"whitegrid\", context=\"paper\")\n",
119
+ "plt.rcParams['font.family'] = 'serif'\n",
120
+ "plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "markdown",
125
+ "source": [
126
+ "## Load BERT pretrained model"
127
+ ],
128
+ "metadata": {
129
+ "id": "4nApCVrOWkR3"
130
+ }
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "source": [
135
+ "# Load pre-trained model and tokenizer\n",
136
+ "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
137
+ "model = AutoModel.from_pretrained(\"bert-base-uncased\")"
138
+ ],
139
+ "metadata": {
140
+ "id": "hT2I1H8BIOp_"
141
+ },
142
+ "execution_count": null,
143
+ "outputs": []
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "source": [
148
+ "## Load data"
149
+ ],
150
+ "metadata": {
151
+ "id": "9KKw24bCWgWj"
152
+ }
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "source": [
157
+ "# Load the OBQA dataset\n",
158
+ "df = pd.read_csv(\"obqa_chains.csv\", sep=\";\")\n",
159
+ "\n",
160
+ "# Ensure necessary columns exist\n",
161
+ "required_columns = ['QID', 'Chain#', 'Question', 'Answer', 'Fact1', 'Fact2', 'Turk']\n",
162
+ "missing_columns = [col for col in required_columns if col not in df.columns]\n",
163
+ "if missing_columns:\n",
164
+ " raise ValueError(f\"Missing required columns: {missing_columns}\")\n",
165
+ "\n",
166
+ "# Preprocess the data\n",
167
+ "df['Question'] = df['Question'] + \" \" + df['Answer'] # Combine question and answer\n",
168
+ "df['is_valid'] = df['Turk'].str.contains('yes', case=False, na=False)"
169
+ ],
170
+ "metadata": {
171
+ "id": "g2f-T9koIOjH"
172
+ },
173
+ "execution_count": null,
174
+ "outputs": []
175
+ },
176
+ {
177
+ "cell_type": "markdown",
178
+ "source": [
179
+ "## Model embeddings"
180
+ ],
181
+ "metadata": {
182
+ "id": "XdN9XTGOWdsh"
183
+ }
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "source": [
188
+ "def get_bert_embedding(text):\n",
189
+ " \"\"\"Get BERT embedding for a given text.\"\"\"\n",
190
+ " inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True, max_length=512)\n",
191
+ " with torch.no_grad():\n",
192
+ " outputs = model(**inputs)\n",
193
+ " return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()\n",
194
+ "\n",
195
+ "def refined_hamiltonian_energy(chain):\n",
196
+ " emb1 = get_bert_embedding(chain['Fact1'])\n",
197
+ " emb2 = get_bert_embedding(chain['Fact2'])\n",
198
+ " emb_q = get_bert_embedding(chain['Question'])\n",
199
+ "\n",
200
+ " # Refined kinetic term: measure of change between facts\n",
201
+ " T = np.linalg.norm(emb2 - emb1)\n",
202
+ "\n",
203
+ " # Refined potential term: measure of relevance to question\n",
204
+ " V = (np.dot(emb1, emb_q) + np.dot(emb2, emb_q)) / 2\n",
205
+ "\n",
206
+ " # Total \"Hamiltonian\" energy: balance between change and relevance\n",
207
+ " H = T - V\n",
208
+ "\n",
209
+ " return H, T, V\n",
210
+ "\n",
211
+ "\n",
212
+ "# Analyze energy conservation\n",
213
+ "def energy_conservation_score(chain):\n",
214
+ " _, T, V = refined_hamiltonian_energy(chain)\n",
215
+ " # Measure how balanced T and V are\n",
216
+ " return 1 / (1 + abs(T - V)) # Now always between 0 and 1, 1 being perfect balance\n",
217
+ "\n",
218
+ "\n",
219
+ "\n",
220
+ "# Calculate refined energies and scores\n",
221
+ "df['H_energy'], df['T_energy'], df['V_energy'] = zip(*df.apply(refined_hamiltonian_energy, axis=1))\n",
222
+ "df['energy_conservation'] = df.apply(energy_conservation_score, axis=1)"
223
+ ],
224
+ "metadata": {
225
+ "id": "3q4EMfekIOZ_"
226
+ },
227
+ "execution_count": null,
228
+ "outputs": []
229
+ },
230
+ {
231
+ "cell_type": "markdown",
232
+ "source": [
233
+ "## Hamiltonian systems"
234
+ ],
235
+ "metadata": {
236
+ "id": "pvQgqhW2Wage"
237
+ }
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "source": [
242
+ "def get_trajectory(row):\n",
243
+ " # Ensure we're working with strings\n",
244
+ " chain = [str(row['Fact1']), str(row['Fact2'])]\n",
245
+ " embeddings = [get_bert_embedding(sentence) for sentence in chain]\n",
246
+ " return np.array(embeddings)\n",
247
+ "\n",
248
+ "def refined_hamiltonian_energy(chain):\n",
249
+ " emb1 = get_bert_embedding(chain['Fact1'])\n",
250
+ " emb2 = get_bert_embedding(chain['Fact2'])\n",
251
+ "\n",
252
+ " # Refined kinetic term: measure of change between facts\n",
253
+ " T = np.linalg.norm(emb2 - emb1)\n",
254
+ "\n",
255
+ " # Refined potential term: mean embedding magnitude (question-independent variant)\n",
256
+ " V = (np.linalg.norm(emb1) + np.linalg.norm(emb2)) / 2\n",
257
+ "\n",
258
+ " # Total \"Hamiltonian\" energy: balance between change and relevance\n",
259
+ " H = T - V\n",
260
+ "\n",
261
+ " return H, T, V\n",
262
+ "\n",
263
+ "\n",
264
+ "def compute_trajectory_energy(trajectory):\n",
265
+ " # Use the stored embeddings directly; re-embedding the string repr of the arrays was a bug\n",
+ " emb1, emb2 = trajectory\n",
+ " return np.linalg.norm(emb2 - emb1) - (np.linalg.norm(emb1) + np.linalg.norm(emb2)) / 2\n",
266
+ "\n",
267
+ "\n",
268
+ "# Compute trajectories for all chains\n",
269
+ "trajectories = df.apply(get_trajectory, axis=1)\n",
270
+ "\n",
271
+ "# Compute energies for trajectories\n",
272
+ "trajectory_energies = trajectories.apply(compute_trajectory_energy)\n"
273
+ ],
274
+ "metadata": {
275
+ "id": "yveIXutUX3ub"
276
+ },
277
+ "execution_count": null,
278
+ "outputs": []
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "source": [
283
+ "# Use PCA to reduce dimensionality for visualization\n",
284
+ "pca = decomposition.PCA(n_components=3)\n",
285
+ "all_points = np.vstack(trajectories.values)\n",
286
+ "pca_result = pca.fit_transform(all_points)\n",
287
+ "\n",
288
+ "trajectories_3d = trajectories.apply(lambda t: pca.transform(t))\n",
289
+ "\n",
290
+ "\n",
291
+ "# Analyze trajectory properties\n",
292
+ "def trajectory_length(traj):\n",
293
+ " return np.sum(np.sqrt(np.sum(np.diff(traj, axis=0)**2, axis=1)))\n",
294
+ "\n",
295
+ "def trajectory_smoothness(traj):\n",
296
+ " first = abs(np.diff(traj[0], axis=0))[0]\n",
297
+ " second = abs(np.diff(traj[1], axis=0))[0]\n",
298
+ " return (first + second)/2\n",
299
+ "\n",
300
+ "traj_properties = pd.DataFrame({\n",
301
+ " 'length': trajectories_3d.apply(trajectory_length),\n",
302
+ " 'smoothness': trajectories_3d.apply(trajectory_smoothness),\n",
303
+ " 'is_valid': df['is_valid']\n",
304
+ "})\n"
305
+ ],
306
+ "metadata": {
307
+ "id": "qFF7_0TD6JRO"
308
+ },
309
+ "execution_count": null,
310
+ "outputs": []
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "source": [
315
+ "# Create the main figure and grid for subplots\n",
316
+ "fig, axs = plt.subplots(2, 2, figsize=(15, 12))\n",
317
+ "fig.suptitle(\"Refined Hamiltonian-Inspired Energy Analysis of Reasoning Chains\", fontsize=16)\n",
318
+ "\n",
319
+ "# Distribution of Hamiltonian Energy\n",
320
+ "sns.histplot(data=df, x='H_energy', ax=axs[0, 0], kde=True, color='blue', bins=50)\n",
321
+ "axs[0, 0].set_title(\"Distribution of Refined Hamiltonian Energy\")\n",
322
+ "axs[0, 0].set_xlabel(\"Hamiltonian Energy\")\n",
323
+ "axs[0, 0].set_ylabel(\"Count\")\n",
324
+ "\n",
325
+ "# Kinetic vs Potential Energy\n",
326
+ "scatter = axs[0, 1].scatter(df['T_energy'], df['V_energy'], c=df['H_energy'], cmap='viridis', s=5, alpha=0.6)\n",
327
+ "axs[0, 1].set_title(\"Refined Kinetic vs Potential Energy\")\n",
328
+ "axs[0, 1].set_xlabel(\"Kinetic Energy (T)\")\n",
329
+ "axs[0, 1].set_ylabel(\"Potential Energy (V)\")\n",
330
+ "plt.colorbar(scatter, ax=axs[0, 1], label=\"Hamiltonian Energy\")\n",
331
+ "\n",
332
+ "# Hamiltonian Energy: Valid vs Invalid Chains\n",
333
+ "valid_chains = df[df['is_valid']]\n",
334
+ "invalid_chains = df[~df['is_valid']]\n",
335
+ "sns.histplot(data=valid_chains, x='H_energy', ax=axs[1, 0], kde=True, color='green', label='Valid Chains', bins=50, alpha=0.6)\n",
336
+ "sns.histplot(data=invalid_chains, x='H_energy', ax=axs[1, 0], kde=True, color='red', label='Invalid Chains', bins=50, alpha=0.6)\n",
337
+ "axs[1, 0].set_title(\"Refined Hamiltonian Energy: Valid vs Invalid Chains\")\n",
338
+ "axs[1, 0].set_xlabel(\"Hamiltonian Energy\")\n",
339
+ "axs[1, 0].set_ylabel(\"Count\")\n",
340
+ "axs[1, 0].legend()\n",
341
+ "\n",
342
+ "# Distribution of Energy Conservation Scores\n",
343
+ "sns.histplot(data=df, x='energy_conservation', ax=axs[1, 1], kde=True, color='orange', bins=50)\n",
344
+ "axs[1, 1].set_title(\"Distribution of Refined Energy Conservation Scores\")\n",
345
+ "axs[1, 1].set_xlabel(\"Energy Conservation Score\")\n",
346
+ "axs[1, 1].set_ylabel(\"Count\")\n",
347
+ "\n",
348
+ "# Adjust layout and display\n",
349
+ "plt.tight_layout()\n",
350
+ "plt.subplots_adjust(top=0.93) # Adjust for main title\n",
351
+ "plt.savefig('refined_hamiltonian_analysis.png', dpi=300, bbox_inches='tight')\n",
352
+ "plt.show()"
353
+ ],
354
+ "metadata": {
355
+ "id": "kqfbA7w3NuPM"
356
+ },
357
+ "execution_count": null,
358
+ "outputs": []
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "source": [
363
+ "# Calculate direction vectors\n",
364
+ "def calculate_direction(trajectory):\n",
365
+ " return trajectory[1] - trajectory[0]\n",
366
+ "\n",
367
+ "direction_vectors = np.array([calculate_direction(traj) for traj in trajectories_3d])\n",
368
+ "\n",
369
+ "# Calculate magnitude and angle of direction vectors\n",
370
+ "magnitudes = np.linalg.norm(direction_vectors, axis=1)\n",
371
+ "angles = np.arctan2(direction_vectors[:, 1], direction_vectors[:, 0])\n",
372
+ "\n",
373
+ "# Add these to the dataframe\n",
374
+ "df['trajectory_magnitude'] = magnitudes\n",
375
+ "df['trajectory_angle'] = angles\n",
376
+ "\n",
377
+ "# Visualize magnitude distribution\n",
378
+ "plt.figure(figsize=(12, 6))\n",
379
+ "sns.histplot(data=df, x='trajectory_magnitude', hue='is_valid', element='step', stat='density', common_norm=False)\n",
380
+ "plt.title('Distribution of Trajectory Magnitudes')\n",
381
+ "plt.xlabel('Magnitude')\n",
382
+ "plt.ylabel('Density')\n",
383
+ "plt.legend(title='Is Valid')\n",
384
+ "plt.tight_layout()\n",
386
+ "plt.savefig('trajectories_magnitude_plot.png', dpi=300, bbox_inches='tight')\n",
387
+ "plt.show()"
388
+ ],
389
+ "metadata": {
390
+ "id": "tYVhJJbPwNxo"
391
+ },
392
+ "execution_count": null,
393
+ "outputs": []
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "source": [
398
+ "plt.figure(figsize=(12, 6))\n",
399
+ "\n",
400
+ "# Define colors explicitly\n",
401
+ "colors = {'Valid': 'blue', 'Invalid': 'red'}\n",
402
+ "\n",
403
+ "# Create a new DataFrame with the data for plotting\n",
404
+ "plot_data = pd.DataFrame({\n",
405
+ " 'Hamiltonian Energy': df['H_energy'],\n",
406
+ " 'Validity': df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n",
407
+ "})\n",
408
+ "\n",
409
+ "# Create the histogram plot with explicit colors\n",
410
+ "sns.histplot(data=plot_data, x='Hamiltonian Energy', hue='Validity',\n",
411
+ " element='step', stat='density', common_norm=False,\n",
412
+ " palette=colors)\n",
413
+ "\n",
414
+ "plt.title('Distribution of Refined Hamiltonian Energy', fontsize=16)\n",
415
+ "plt.xlabel('Hamiltonian Energy', fontsize=14)\n",
416
+ "plt.ylabel('Density', fontsize=14)\n",
417
+ "\n",
418
+ "# Adjust legend\n",
419
+ "plt.legend(title='Chain Validity', title_fontsize='13', fontsize='12')\n",
420
+ "\n",
421
+ "# Add vertical lines for mean energies\n",
422
+ "plt.axvline(x=-60.889, color='blue', linestyle='--', label='Mean Valid')\n",
423
+ "plt.axvline(x=-53.816, color='red', linestyle='--', label='Mean Invalid')\n",
424
+ "\n",
425
+ "# Add text annotations for mean energies\n",
426
+ "plt.text(-60.889, plt.gca().get_ylim()[1], 'Mean Valid',\n",
427
+ " rotation=90, va='top', ha='right', color='blue')\n",
428
+ "plt.text(-53.816, plt.gca().get_ylim()[1], 'Mean Invalid',\n",
429
+ " rotation=90, va='top', ha='left', color='red')\n",
430
+ "\n",
431
+ "plt.tight_layout()\n",
432
+ "plt.savefig('refined_hamiltonian_energy_distribution.png', dpi=300, bbox_inches='tight')\n",
433
+ "plt.show()"
434
+ ],
435
+ "metadata": {
436
+ "id": "m1fHZ-NpMnHD"
437
+ },
438
+ "execution_count": null,
439
+ "outputs": []
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "source": [
444
+ "# Perform PCA to reduce to 2 dimensions\n",
445
+ "pca = decomposition.PCA(n_components=2)\n",
446
+ "trajectories_2d = pca.fit_transform(np.vstack(trajectories))\n",
447
+ "\n",
448
+ "# Reshape the data back into trajectories\n",
449
+ "trajectories_2d = trajectories_2d.reshape(len(trajectories), -1, 2)\n",
450
+ "\n",
451
+ "# Create the plot\n",
452
+ "plt.figure(figsize=(12, 10))\n",
453
+ "plt.style.use('seaborn-v0_8-whitegrid') # renamed from 'seaborn-whitegrid' in matplotlib 3.6+\n",
454
+ "sns.set_context(\"paper\")\n",
455
+ "plt.rcParams['font.family'] = 'serif'\n",
456
+ "plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']\n",
457
+ "\n",
458
+ "# Plot trajectories\n",
459
+ "valid_trajectories = []\n",
460
+ "invalid_trajectories = []\n",
461
+ "for i, traj in enumerate(trajectories_2d[:100]): # Limit to 100 for clarity\n",
462
+ " if df.iloc[i]['is_valid']:\n",
463
+ " valid_trajectories.append(traj)\n",
464
+ " color = 'green'\n",
465
+ " else:\n",
466
+ " invalid_trajectories.append(traj)\n",
467
+ " color = 'red'\n",
468
+ " plt.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.5)\n",
469
+ " plt.scatter(traj[0, 0], traj[0, 1], color=color, s=20, marker='o')\n",
470
+ " plt.scatter(traj[-1, 0], traj[-1, 1], color=color, s=20, marker='s')\n",
471
+ "\n",
472
+ "# Calculate the vector field based on the average direction of trajectories\n",
473
+ "grid_size = 20\n",
474
+ "x = np.linspace(trajectories_2d[:, :, 0].min(), trajectories_2d[:, :, 0].max(), grid_size)\n",
475
+ "y = np.linspace(trajectories_2d[:, :, 1].min(), trajectories_2d[:, :, 1].max(), grid_size)\n",
476
+ "X, Y = np.meshgrid(x, y)\n",
477
+ "\n",
478
+ "U = np.zeros_like(X)\n",
479
+ "V = np.zeros_like(Y)\n",
480
+ "\n",
481
+ "for i in range(grid_size):\n",
482
+ " for j in range(grid_size):\n",
483
+ " nearby_trajectories = [traj for traj in trajectories_2d if\n",
484
+ " ((x[i]-0.5 < traj[:, 0]) & (traj[:, 0] < x[i]+0.5)).any() and\n",
485
+ " ((y[j]-0.5 < traj[:, 1]) & (traj[:, 1] < y[j]+0.5)).any()]\n",
486
+ " if nearby_trajectories:\n",
487
+ " directions = np.diff(nearby_trajectories, axis=1)\n",
488
+ " avg_direction = np.mean(directions, axis=(0, 1))\n",
489
+ " U[j, i], V[j, i] = avg_direction\n",
490
+ "\n",
491
+ "# Normalize the vector field\n",
492
+ "magnitude = np.sqrt(U**2 + V**2)\n",
493
+ "U = U / np.where(magnitude > 0, magnitude, 1)\n",
494
+ "V = V / np.where(magnitude > 0, magnitude, 1)\n",
495
+ "\n",
496
+ "plt.streamplot(X, Y, U, V, density=1, color='gray', linewidth=0.5, arrowsize=0.5)\n",
497
+ "\n",
498
+ "# Find key points using KMeans clustering\n",
499
+ "n_clusters = 5 # Adjust this number based on how many key points you want\n",
500
+ "kmeans = cluster.KMeans(n_clusters=n_clusters, n_init=10)\n",
501
+ "flattened_trajectories = trajectories_2d.reshape(-1, 2)\n",
502
+ "kmeans.fit(flattened_trajectories)\n",
503
+ "key_points = kmeans.cluster_centers_\n",
504
+ "\n",
505
+ "# Plot key points\n",
506
+ "plt.scatter(key_points[:, 0], key_points[:, 1], color='blue', s=100, marker='*', zorder=5)\n",
507
+ "\n",
508
+ "# Add labels to key points\n",
509
+ "for i, point in enumerate(key_points):\n",
510
+ " plt.annotate(f'Key Point {i+1}', (point[0], point[1]), xytext=(5, 5),\n",
511
+ " textcoords='offset points', fontsize=8, color='blue')\n",
512
+ "\n",
513
+ "# Add labels and title\n",
514
+ "plt.xlabel('PCA 1')\n",
515
+ "plt.ylabel('PCA 2')\n",
516
+ "plt.title('2D Reasoning Trajectories with Phase Space Features and Key Points')\n",
517
+ "\n",
518
+ "# Add a legend\n",
519
+ "valid_line = plt.Line2D([], [], color='green', label='Valid Chains')\n",
520
+ "invalid_line = plt.Line2D([], [], color='red', label='Invalid Chains')\n",
521
+ "vector_field_line = plt.Line2D([], [], color='gray', label='Vector Field')\n",
522
+ "key_point_marker = plt.Line2D([], [], color='blue', marker='*', linestyle='None',\n",
523
+ " markersize=10, label='Key Points')\n",
524
+ "plt.legend(handles=[valid_line, invalid_line, vector_field_line, key_point_marker])\n",
525
+ "\n",
526
+ "# Show the plot\n",
527
+ "plt.tight_layout()\n",
528
+ "plt.savefig('2d_reasoning_trajectories_with_key_points.png', dpi=300, bbox_inches='tight')\n",
529
+ "plt.show()"
530
+ ],
531
+ "metadata": {
532
+ "id": "m38JkWLcQKCc"
533
+ },
534
+ "execution_count": null,
535
+ "outputs": []
536
+ },
537
+ {
538
+ "cell_type": "code",
539
+ "source": [
540
+ "fig = plt.figure(figsize=(10, 8))\n",
541
+ "ax = fig.add_subplot(111, projection='3d')\n",
542
+ "\n",
543
+ "for i, trajectory in enumerate(trajectories_3d[:100]): # Limit to first 100 for clarity\n",
544
+ " color = 'green' if df.iloc[i]['is_valid'] else 'red'\n",
545
+ " ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], color=color, alpha=0.5)\n",
546
+ " ax.scatter(trajectory[0, 0], trajectory[0, 1], trajectory[0, 2], color=color, s=20)\n",
547
+ " ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2], color=color, s=20, marker='s')\n",
548
+ "\n",
549
+ "ax.set_xlabel('PCA 1')\n",
550
+ "ax.set_ylabel('PCA 2')\n",
551
+ "ax.set_zlabel('PCA 3')\n",
552
+ "ax.set_title('Reasoning Trajectories in 3D Embedding Space')\n",
553
+ "plt.tight_layout()\n",
554
+ "plt.show()"
555
+ ],
556
+ "metadata": {
557
+ "id": "nVVADjWNNVy_"
558
+ },
559
+ "execution_count": null,
560
+ "outputs": []
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "source": [
565
+ "def compute_vector_field(trajectories, grid_size=10):\n",
566
+ " # Determine the bounds of the space\n",
567
+ " all_points = np.vstack(trajectories)\n",
568
+ " mins = np.min(all_points, axis=0)\n",
569
+ " maxs = np.max(all_points, axis=0)\n",
570
+ "\n",
571
+ " # Create a grid\n",
572
+ " x = np.linspace(mins[0], maxs[0], grid_size)\n",
573
+ " y = np.linspace(mins[1], maxs[1], grid_size)\n",
574
+ " z = np.linspace(mins[2], maxs[2], grid_size)\n",
575
+ " X, Y, Z = np.meshgrid(x, y, z, indexing='ij') # match the U[i, j, k] indexing below\n",
576
+ "\n",
577
+ " U = np.zeros((grid_size, grid_size, grid_size))\n",
578
+ " V = np.zeros((grid_size, grid_size, grid_size))\n",
579
+ " W = np.zeros((grid_size, grid_size, grid_size))\n",
580
+ "\n",
581
+ " # Compute average direction for each grid cell\n",
582
+ " for trajectory in trajectories:\n",
583
+ " directions = np.diff(trajectory, axis=0)\n",
584
+ " for direction, point in zip(directions, trajectory[:-1]):\n",
585
+ " i, j, k = np.floor((point - mins) / (maxs - mins) * (grid_size - 1)).astype(int)\n",
586
+ " U[i, j, k] += direction[0]\n",
587
+ " V[i, j, k] += direction[1]\n",
588
+ " W[i, j, k] += direction[2]\n",
589
+ "\n",
590
+ " # Normalize\n",
591
+ " magnitude = np.sqrt(U**2 + V**2 + W**2)\n",
592
+ " U /= np.where(magnitude > 0, magnitude, 1)\n",
593
+ " V /= np.where(magnitude > 0, magnitude, 1)\n",
594
+ " W /= np.where(magnitude > 0, magnitude, 1)\n",
595
+ "\n",
596
+ " return X, Y, Z, U, V, W\n",
597
+ "\n",
598
+ "# Set up the figure and 3D axis\n",
599
+ "fig = plt.figure(figsize=(12, 10))\n",
600
+ "ax = fig.add_subplot(111, projection='3d')\n",
601
+ "\n",
602
+ "# Plot trajectories\n",
603
+ "for i, trajectory in enumerate(trajectories_3d[:100]): # Limit to first 100 for clarity\n",
604
+ " color = 'green' if df.iloc[i]['is_valid'] else 'red'\n",
605
+ " ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], color=color, alpha=0.5)\n",
606
+ " ax.scatter(trajectory[0, 0], trajectory[0, 1], trajectory[0, 2], color=color, s=20)\n",
607
+ " ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2], color=color, s=20, marker='s')\n",
608
+ "\n",
609
+ "# Compute and plot vector field\n",
610
+ "X, Y, Z, U, V, W = compute_vector_field(trajectories_3d[:100])\n",
611
+ "ax.quiver(X, Y, Z, U, V, W, length=0.5, normalize=True, color='blue', alpha=0.3)\n",
612
+ "\n",
613
+ "ax.set_xlabel('PCA 1')\n",
614
+ "ax.set_ylabel('PCA 2')\n",
615
+ "ax.set_zlabel('PCA 3')\n",
616
+ "ax.set_title('Reasoning Trajectories and Phase Space in 3D Embedding Space')\n",
617
+ "\n",
618
+ "plt.tight_layout()\n",
619
+ "plt.savefig('3d_phase_space_plot.png', dpi=300, bbox_inches='tight')\n",
620
+ "plt.show()"
621
+ ],
622
+ "metadata": {
623
+ "id": "l0UmPM8xftuv"
624
+ },
625
+ "execution_count": null,
626
+ "outputs": []
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "source": [
631
+ "plt.figure(figsize=(10, 6))\n",
632
+ "\n",
633
+ "# Create the histogram plot\n",
634
+ "sns.histplot(data=df, x='energy_conservation', kde=True, bins=50, color='green')\n",
635
+ "\n",
636
+ "# Set the title and labels\n",
637
+ "plt.title(\"Distribution of Energy Conservation Scores\", fontsize=16)\n",
638
+ "plt.xlabel(\"Energy Conservation Score\", fontsize=12)\n",
639
+ "plt.ylabel(\"Frequency\", fontsize=12)\n",
640
+ "\n",
641
+ "# Adjust layout and display\n",
642
+ "plt.tight_layout()\n",
643
+ "plt.savefig('energy_conservation_distribution.png', dpi=300, bbox_inches='tight')\n",
644
+ "plt.show()"
645
+ ],
646
+ "metadata": {
647
+ "id": "qca1p7PhOaU6"
648
+ },
649
+ "execution_count": null,
650
+ "outputs": []
651
+ },
652
+ {
653
+ "cell_type": "code",
654
+ "source": [
655
+ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
656
+ "\n",
657
+ "sns.histplot(data=df, x='trajectory_magnitude', hue='is_valid', element='step', stat='density', common_norm=False, ax=ax1)\n",
658
+ "ax1.set_title('Distribution of Trajectory Magnitudes')\n",
659
+ "ax1.set_xlabel('Magnitude')\n",
660
+ "ax1.set_ylabel('Density')\n",
661
+ "\n",
662
+ "sns.histplot(data=df, x='trajectory_angle', hue='is_valid', element='step', stat='density', common_norm=False, ax=ax2)\n",
663
+ "ax2.set_title('Distribution of Trajectory Angles')\n",
664
+ "ax2.set_xlabel('Angle (radians)')\n",
665
+ "ax2.set_ylabel('Density')\n",
666
+ "\n",
667
+ "plt.tight_layout()\n",
668
+ "plt.savefig('magnitude_angle_distribution.png', dpi=300, bbox_inches='tight')\n",
669
+ "plt.close()"
670
+ ],
671
+ "metadata": {
672
+ "id": "I8VrMb6MMsOc"
673
+ },
674
+ "execution_count": null,
675
+ "outputs": []
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "source": [
680
+ "# Additional analysis\n",
681
+ "print(f\"Average Energy Conservation Score: {df['energy_conservation'].mean():.4f}\")\n",
682
+ "print(f\"Correlation between Energy Conservation and Validity: {df['energy_conservation'].corr(df['is_valid']):.4f}\")\n",
683
+ "print(f\"Average Hamiltonian Energy for Valid Chains: {valid_chains['H_energy'].mean():.4f}\")\n",
684
+ "print(f\"Average Hamiltonian Energy for Invalid Chains: {invalid_chains['H_energy'].mean():.4f}\")\n",
685
+ "\n",
686
+ "# T-test for difference in Hamiltonian Energy\n",
687
+ "t_stat, p_value = stats.ttest_ind(valid_chains['H_energy'], invalid_chains['H_energy'])\n",
688
+ "print(f\"\\nT-test for difference in Hamiltonian Energy:\")\n",
689
+ "print(f\"t-statistic: {t_stat:.4f}\")\n",
690
+ "print(f\"p-value: {p_value:.4f}\")"
691
+ ],
692
+ "metadata": {
693
+ "id": "FHmMSmNAI-qc"
694
+ },
695
+ "execution_count": null,
696
+ "outputs": []
697
+ },
698
+ {
699
+ "cell_type": "markdown",
700
+ "source": [
701
+ "## Geometric analysis"
702
+ ],
703
+ "metadata": {
704
+ "id": "1s_DosZEWVhy"
705
+ }
706
+ },
707
+ {
708
+ "cell_type": "code",
709
+ "source": [
710
+ "fig = plt.figure(figsize=(10, 8))\n",
711
+ "ax = fig.add_subplot(111, projection='3d')\n",
712
+ "\n",
713
+ "for i, trajectory in enumerate(trajectories_3d[:100]): # Limit to first 100 for clarity\n",
714
+ " color = 'green' if df.iloc[i]['is_valid'] else 'red'\n",
715
+ " ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], color=color, alpha=0.5)\n",
716
+ " ax.scatter(trajectory[0, 0], trajectory[0, 1], trajectory[0, 2], color=color, s=20)\n",
717
+ " ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2], color=color, s=20, marker='s')\n",
718
+ "\n",
719
+ "ax.set_xlabel('PCA 1')\n",
720
+ "ax.set_ylabel('PCA 2')\n",
721
+ "ax.set_zlabel('PCA 3')\n",
722
+ "ax.set_title('Reasoning Trajectories in 3D Embedding Space')\n",
723
+ "plt.tight_layout()\n",
724
+ "plt.savefig('3d_trajectories.png', dpi=300, bbox_inches='tight')\n",
725
+ "plt.close()\n",
726
+ "\n",
727
+ "# 2. Trajectory Energy by Chain Index\n",
728
+ "plt.figure(figsize=(10, 6))\n",
729
+ "sns.scatterplot(x=df.index, y=trajectory_energies, hue=df['is_valid'], palette={True: 'green', False: 'red'})\n",
730
+ "plt.title('Trajectory Energy by Chain Index')\n",
731
+ "plt.xlabel('Chain Index')\n",
732
+ "plt.ylabel('Energy')\n",
733
+ "plt.legend(title='Is Valid')\n",
734
+ "plt.tight_layout()\n",
735
+ "plt.savefig('trajectory_energy.png', dpi=300, bbox_inches='tight')\n",
736
+ "plt.close()"
737
+ ],
738
+ "metadata": {
739
+ "id": "2Sz-nqGA9p8B"
740
+ },
741
+ "execution_count": null,
742
+ "outputs": []
743
+ },
744
+ {
745
+ "cell_type": "code",
746
+ "source": [
747
+ "# Energy Plot\n",
748
+ "plt.figure(figsize=(12, 6))\n",
749
+ "sns.scatterplot(x=df.index, y=trajectory_energies, hue=df['is_valid'], palette={True: 'green', False: 'red'})\n",
750
+ "plt.title('Trajectory Energy by Chain Index')\n",
751
+ "plt.xlabel('Chain Index')\n",
752
+ "plt.ylabel('Energy')\n",
753
+ "plt.legend(title='Is Valid')\n",
754
+ "plt.tight_layout()\n",
755
+ "plt.show()"
756
+ ],
757
+ "metadata": {
758
+ "id": "5rN0K7tM_68P"
759
+ },
760
+ "execution_count": null,
761
+ "outputs": []
762
+ },
763
+ {
764
+ "cell_type": "code",
765
+ "source": [
766
+ "plt.figure(figsize=(12, 6))\n",
767
+ "\n",
768
+ "# Define colors explicitly\n",
769
+ "colors = {'Valid': 'green', 'Invalid': 'red'}\n",
770
+ "\n",
771
+ "# Create the histogram plot with explicit colors\n",
772
+ "sns.histplot(data=pd.DataFrame({'Energy': trajectory_energies, 'Is Valid': df['is_valid'].map({True: 'Valid', False: 'Invalid'})}),\n",
773
+ " x='Energy', hue='Is Valid', element='step', stat='density', common_norm=False,\n",
774
+ " palette=colors)\n",
775
+ "\n",
776
+ "plt.title('Distribution of Trajectory Energies', fontsize=16)\n",
777
+ "plt.xlabel('Energy', fontsize=14)\n",
778
+ "plt.ylabel('Density', fontsize=14)\n",
779
+ "\n",
780
+ "# Create a custom legend\n",
781
+ "handles = [plt.Rectangle((0,0),1,1, color=color) for color in colors.values()]\n",
782
+ "labels = list(colors.keys())\n",
783
+ "plt.legend(handles, labels, title='Trajectory Validity', title_fontsize='13', fontsize='12')\n",
784
+ "\n",
785
+ "plt.tight_layout()\n",
786
+ "plt.savefig('energy_distribution_plot.png', dpi=300, bbox_inches='tight')\n",
787
+ "plt.show()"
788
+ ],
789
+ "metadata": {
790
+ "id": "iRG8GKRF__3a"
791
+ },
792
+ "execution_count": null,
793
+ "outputs": []
794
+ },
795
+ {
796
+ "cell_type": "code",
797
+ "source": [
798
+ "# Distribution of Trajectory Magnitudes and Angles\n",
799
+ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
800
+ "\n",
801
+ "sns.histplot(data=df, x='trajectory_magnitude', hue='is_valid', element='step', stat='density', common_norm=False, ax=ax1)\n",
802
+ "ax1.set_title('Distribution of Trajectory Magnitudes')\n",
803
+ "ax1.set_xlabel('Magnitude')\n",
804
+ "ax1.set_ylabel('Density')\n",
805
+ "\n",
806
+ "sns.histplot(data=df, x='trajectory_angle', hue='is_valid', element='step', stat='density', common_norm=False, ax=ax2)\n",
807
+ "ax2.set_title('Distribution of Trajectory Angles')\n",
808
+ "ax2.set_xlabel('Angle (radians)')\n",
809
+ "ax2.set_ylabel('Density')\n",
810
+ "\n",
811
+ "plt.tight_layout()\n",
812
+ "plt.savefig('magnitude_angle_distribution.png', dpi=300, bbox_inches='tight')\n",
813
+ "plt.close()"
814
+ ],
815
+ "metadata": {
816
+ "id": "yLJie7VYoas6"
817
+ },
818
+ "execution_count": null,
819
+ "outputs": []
820
+ },
821
+ {
822
+ "cell_type": "code",
823
+ "source": [
824
+ "# Trajectory Magnitude vs Angle\n",
825
+ "plt.figure(figsize=(10, 8))\n",
826
+ "sns.scatterplot(data=df, x='trajectory_angle', y='trajectory_magnitude', hue='is_valid', alpha=0.6)\n",
827
+ "plt.title('Trajectory Magnitude vs Angle')\n",
828
+ "plt.xlabel('Angle (radians)')\n",
829
+ "plt.ylabel('Magnitude')\n",
830
+ "plt.legend(title='Is Valid')\n",
831
+ "plt.tight_layout()\n",
832
+ "plt.savefig('magnitude_vs_angle.png', dpi=300, bbox_inches='tight')\n",
833
+ "plt.close()\n",
834
+ "\n",
835
+ "# 6. Trajectory Properties Comparison\n",
836
+ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n",
837
+ "\n",
838
+ "sns.boxplot(x='is_valid', y='length', data=traj_properties, ax=ax1)\n",
839
+ "ax1.set_title('Trajectory Length')\n",
840
+ "ax1.set_xlabel('Is Valid')\n",
841
+ "ax1.set_ylabel('Length')\n",
842
+ "\n",
843
+ "sns.boxplot(x='is_valid', y='smoothness', data=traj_properties, ax=ax2)\n",
844
+ "ax2.set_title('Trajectory Smoothness')\n",
845
+ "ax2.set_xlabel('Is Valid')\n",
846
+ "ax2.set_ylabel('Smoothness')\n",
847
+ "\n",
848
+ "plt.tight_layout()\n",
849
+ "plt.savefig('trajectory_properties.png', dpi=300, bbox_inches='tight')\n",
850
+ "plt.close()"
851
+ ],
852
+ "metadata": {
853
+ "id": "OOasgefio41H"
854
+ },
855
+ "execution_count": null,
856
+ "outputs": []
857
+ },
858
+ {
859
+ "cell_type": "code",
860
+ "source": [
861
+ "plt.figure(figsize=(12, 8))\n",
862
+ "\n",
863
+ "# Define colors explicitly\n",
864
+ "colors = {'Valid': 'blue', 'Invalid': 'red'}\n",
865
+ "\n",
866
+ "# Prepare the data\n",
867
+ "plot_data = df.copy()\n",
868
+ "plot_data['Validity'] = df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n",
869
+ "\n",
870
+ "# Create the scatter plot with explicit colors\n",
871
+ "sns.scatterplot(data=plot_data, x='trajectory_angle', y='trajectory_magnitude', hue='Validity',\n",
872
+ " palette=colors, alpha=0.6)\n",
873
+ "\n",
874
+ "plt.title('Trajectory Magnitude vs Angle', fontsize=16)\n",
875
+ "plt.xlabel('Angle (radians)', fontsize=14)\n",
876
+ "plt.ylabel('Magnitude', fontsize=14)\n",
877
+ "\n",
878
+ "# Create custom legend handles\n",
879
+ "handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, alpha=0.6)\n",
880
+ " for color in colors.values()]\n",
881
+ "labels = list(colors.keys())\n",
882
+ "\n",
883
+ "# Add the legend with custom handles\n",
884
+ "plt.legend(handles, labels, title='Chain Validity', title_fontsize='13', fontsize='12')\n",
885
+ "\n",
886
+ "plt.tight_layout()\n",
887
+ "plt.savefig('refined_magnitude_vs_angle_plot.png', dpi=300, bbox_inches='tight')\n",
888
+ "plt.show()\n",
889
+ "\n",
890
+ "# Calculate and print statistical information\n",
891
+ "valid_data = df[df['is_valid']]\n",
892
+ "invalid_data = df[~df['is_valid']]\n",
893
+ "\n",
894
+ "print(\"Statistical Information:\")\n",
895
+ "print(f\"Correlation between Angle and Magnitude (overall): {df['trajectory_angle'].corr(df['trajectory_magnitude']):.3f}\")\n",
896
+ "print(f\"Correlation for Valid Chains: {valid_data['trajectory_angle'].corr(valid_data['trajectory_magnitude']):.3f}\")\n",
897
+ "print(f\"Correlation for Invalid Chains: {invalid_data['trajectory_angle'].corr(invalid_data['trajectory_magnitude']):.3f}\")\n",
898
+ "\n",
899
+ "# Perform t-tests\n",
900
+ "t_stat_angle, p_value_angle = stats.ttest_ind(valid_data['trajectory_angle'], invalid_data['trajectory_angle'])\n",
901
+ "t_stat_mag, p_value_mag = stats.ttest_ind(valid_data['trajectory_magnitude'], invalid_data['trajectory_magnitude'])\n",
902
+ "\n",
903
+ "print(\"\\nT-test for difference in Trajectory Angle:\")\n",
904
+ "print(f\"t-statistic: {t_stat_angle:.4f}\")\n",
905
+ "print(f\"p-value: {p_value_angle:.4f}\")\n",
906
+ "\n",
907
+ "print(\"\\nT-test for difference in Trajectory Magnitude:\")\n",
908
+ "print(f\"t-statistic: {t_stat_mag:.4f}\")\n",
909
+ "print(f\"p-value: {p_value_mag:.4f}\")\n",
910
+ "\n",
911
+ "# Calculate and print mean values\n",
912
+ "print(\"\\nMean Values:\")\n",
913
+ "print(f\"Mean Angle for Valid Chains: {valid_data['trajectory_angle'].mean():.3f}\")\n",
914
+ "print(f\"Mean Angle for Invalid Chains: {invalid_data['trajectory_angle'].mean():.3f}\")\n",
915
+ "print(f\"Mean Magnitude for Valid Chains: {valid_data['trajectory_magnitude'].mean():.3f}\")\n",
916
+ "print(f\"Mean Magnitude for Invalid Chains: {invalid_data['trajectory_magnitude'].mean():.3f}\")"
917
+ ],
918
+ "metadata": {
919
+ "id": "6pBMYGiKBR7f"
920
+ },
921
+ "execution_count": null,
922
+ "outputs": []
923
+ },
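The correlations printed above use `Series.corr`, which defaults to the Pearson coefficient. A quick sketch (hypothetical `angle`/`magnitude` values) showing it agrees with `np.corrcoef` and returns exactly 1 for a perfectly linear pair:

```python
import numpy as np
import pandas as pd

# Hypothetical perfectly linear pair (not notebook data)
angle = pd.Series([0.1, 0.2, 0.3, 0.4])
magnitude = pd.Series([2.0, 4.0, 6.0, 8.0])

r_pandas = angle.corr(magnitude)            # Pearson by default
r_numpy = np.corrcoef(angle, magnitude)[0, 1]
```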
924
+ {
925
+ "cell_type": "code",
926
+ "source": [
927
+ "# Statistical tests\n",
928
+ "valid_mag = df[df['is_valid']]['trajectory_magnitude']\n",
929
+ "invalid_mag = df[~df['is_valid']]['trajectory_magnitude']\n",
930
+ "mag_ttest = ttest_ind(valid_mag, invalid_mag)\n",
931
+ "\n",
932
+ "valid_ang = df[df['is_valid']]['trajectory_angle']\n",
933
+ "invalid_ang = df[~df['is_valid']]['trajectory_angle']\n",
934
+ "ang_ttest = ttest_ind(valid_ang, invalid_ang)\n",
935
+ "\n",
936
+ "print(\"T-test for trajectory magnitude:\", mag_ttest)\n",
937
+ "print(\"T-test for trajectory angle:\", ang_ttest)\n",
938
+ "\n",
939
+ "# Correlation with energy\n",
940
+ "mag_energy_corr = df['trajectory_magnitude'].corr(df['H_energy'])\n",
941
+ "ang_energy_corr = df['trajectory_angle'].corr(df['H_energy'])\n",
942
+ "\n",
943
+ "print(\"Correlation between magnitude and H energy:\", mag_energy_corr)\n",
944
+ "print(\"Correlation between angle and H energy:\", ang_energy_corr)"
945
+ ],
946
+ "metadata": {
947
+ "id": "i2ccr--MBXYa"
948
+ },
949
+ "execution_count": null,
950
+ "outputs": []
951
+ },
952
+ {
953
+ "cell_type": "code",
954
+ "source": [
955
+ "def calculate_curvature(trajectory):\n",
+ " # Menger curvature of the triangle formed by the start, middle, and end points\n",
+ " a = np.linalg.norm(trajectory[1] - trajectory[0])\n",
+ " b = np.linalg.norm(trajectory[2] - trajectory[1])\n",
+ " c = np.linalg.norm(trajectory[2] - trajectory[0])\n",
+ "\n",
+ " if a * b * c == 0: # Degenerate triangle (coincident points)\n",
+ " return 0.0\n",
+ "\n",
+ " s = (a + b + c) / 2\n",
+ " area = np.sqrt(max(s * (s-a) * (s-b) * (s-c), 0.0)) # Clamp rounding error\n",
+ "\n",
+ " return 4 * area / (a * b * c)\n",
966
+ "\n",
967
+ "def calculate_rate_of_change(trajectory):\n",
968
+ " # Calculate the rate of change between each pair of consecutive points\n",
969
+ " changes = np.diff(trajectory, axis=0)\n",
970
+ " rates = np.linalg.norm(changes, axis=1)\n",
971
+ " return np.mean(rates)\n",
972
+ "\n",
973
+ "# Calculate curvature and rate of change\n",
974
+ "curvatures = []\n",
975
+ "rates_of_change = []\n",
976
+ "\n",
977
+ "for traj in trajectories_3d:\n",
978
+ " curvatures.append(calculate_curvature(traj))\n",
979
+ " rates_of_change.append(calculate_rate_of_change(traj))\n",
980
+ "\n",
981
+ "# Add these to the dataframe\n",
982
+ "df['curvature'] = curvatures\n",
983
+ "df['rate_of_change'] = rates_of_change\n",
984
+ "\n",
985
+ "\n",
986
+ "plt.figure(figsize=(12, 6))\n",
987
+ "\n",
988
+ "# Define colors explicitly\n",
989
+ "colors = {'Valid': 'blue', 'Invalid': 'red'}\n",
990
+ "\n",
991
+ "# Prepare the data\n",
992
+ "plot_data = pd.DataFrame({\n",
993
+ " 'Curvature': df['curvature'],\n",
994
+ " 'Validity': df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n",
995
+ "})\n",
996
+ "\n",
997
+ "# Create the histogram plot with explicit colors\n",
998
+ "sns.histplot(data=plot_data, x='Curvature', hue='Validity',\n",
999
+ " element='step', stat='density', common_norm=False,\n",
1000
+ " palette=colors)\n",
1001
+ "\n",
1002
+ "plt.title('Distribution of Trajectory Curvatures', fontsize=16)\n",
1003
+ "plt.xlabel('Curvature', fontsize=14)\n",
1004
+ "plt.ylabel('Density', fontsize=14)\n",
1005
+ "\n",
1006
+ "# Adjust legend\n",
1007
+ "plt.legend(title='Chain Validity', title_fontsize='13', fontsize='12')\n",
1008
+ "\n",
1009
+ "# Calculate mean curvatures for valid and invalid chains\n",
1010
+ "mean_valid = df[df['is_valid']]['curvature'].mean()\n",
1011
+ "mean_invalid = df[~df['is_valid']]['curvature'].mean()\n",
1012
+ "\n",
1013
+ "# Add vertical lines for mean curvatures\n",
1014
+ "plt.axvline(x=mean_valid, color='blue', linestyle='--', label='Mean Valid')\n",
1015
+ "plt.axvline(x=mean_invalid, color='red', linestyle='--', label='Mean Invalid')\n",
1016
+ "\n",
1017
+ "# Add text annotations for mean curvatures\n",
1018
+ "plt.text(mean_valid, plt.gca().get_ylim()[1], f'Mean Valid: {mean_valid:.3f}',\n",
1019
+ " rotation=90, va='top', ha='right', color='blue')\n",
1020
+ "plt.text(mean_invalid, plt.gca().get_ylim()[1], f'Mean Invalid: {mean_invalid:.3f}',\n",
1021
+ " rotation=90, va='top', ha='left', color='red')\n",
1022
+ "\n",
1023
+ "plt.tight_layout()\n",
1024
+ "plt.savefig('refined_curvature_distribution.png', dpi=300, bbox_inches='tight')\n",
1025
+ "plt.show()\n",
1026
+ "\n",
1027
+ "# Calculate and print statistical information\n",
1028
+ "valid_curv = df[df['is_valid']]['curvature']\n",
1029
+ "invalid_curv = df[~df['is_valid']]['curvature']\n",
1030
+ "t_stat, p_value = stats.ttest_ind(valid_curv, invalid_curv)"
1031
+ ],
1032
+ "metadata": {
1033
+ "id": "BlXQkEKjCrSK"
1034
+ },
1035
+ "execution_count": null,
1036
+ "outputs": []
1037
+ },
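The three-point curvature used above is the Menger curvature, 4·Area/(abc), which equals the reciprocal of the radius of the circle through the three points. A standalone sketch (the helper name `menger_curvature` is illustrative) verifying this on three points of a circle of radius 2:

```python
import numpy as np

def menger_curvature(p0, p1, p2):
    """4*Area/(a*b*c): reciprocal of the circumscribed circle's radius."""
    a = np.linalg.norm(p1 - p0)
    b = np.linalg.norm(p2 - p1)
    c = np.linalg.norm(p2 - p0)
    s = (a + b + c) / 2
    area = np.sqrt(max(s * (s - a) * (s - b) * (s - c), 0.0))
    return 4 * area / (a * b * c)

# Three points on a circle of radius R=2 -> curvature should be 1/R = 0.5
R = 2.0
kappa = menger_curvature(np.array([R, 0.0]), np.array([0.0, R]), np.array([-R, 0.0]))
```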
1038
+ {
1039
+ "cell_type": "code",
1040
+ "source": [
1041
+ "plt.figure(figsize=(12, 6))\n",
1042
+ "\n",
1043
+ "# Define colors explicitly\n",
1044
+ "colors = {'Valid': 'blue', 'Invalid': 'red'}\n",
1045
+ "\n",
1046
+ "# Prepare the data\n",
1047
+ "plot_data = pd.DataFrame({\n",
1048
+ " 'Rate of Change': df['rate_of_change'],\n",
1049
+ " 'Validity': df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n",
1050
+ "})\n",
1051
+ "\n",
1052
+ "# Create the histogram plot with explicit colors\n",
1053
+ "sns.histplot(data=plot_data, x='Rate of Change', hue='Validity',\n",
1054
+ " element='step', stat='density', common_norm=False,\n",
1055
+ " palette=colors)\n",
1056
+ "\n",
1057
+ "plt.title('Distribution of Trajectory Rates of Change', fontsize=16)\n",
1058
+ "plt.xlabel('Rate of Change', fontsize=14)\n",
1059
+ "plt.ylabel('Density', fontsize=14)\n",
1060
+ "\n",
1061
+ "# Create custom legend handles\n",
1062
+ "handles = [plt.Rectangle((0,0),1,1, color=colors[label]) for label in colors]\n",
1063
+ "labels = list(colors.keys())\n",
1064
+ "\n",
1065
+ "# Add the legend with custom handles\n",
1066
+ "plt.legend(handles, labels, title='Chain Validity', title_fontsize='13', fontsize='12')\n",
1067
+ "\n",
1068
+ "plt.tight_layout()\n",
1069
+ "plt.savefig('simplified_rate_of_change_distribution.png', dpi=300, bbox_inches='tight')\n",
1070
+ "plt.show()\n",
1071
+ "\n",
1072
+ "# Calculate and print statistical information\n",
1073
+ "valid_roc = df[df['is_valid']]['rate_of_change']\n",
1074
+ "invalid_roc = df[~df['is_valid']]['rate_of_change']\n",
1075
+ "t_stat, p_value = stats.ttest_ind(valid_roc, invalid_roc)\n",
1076
+ "\n",
1077
+ "mean_valid = valid_roc.mean()\n",
1078
+ "mean_invalid = invalid_roc.mean()\n",
1079
+ "\n",
1080
+ "print(\"Distribution of Trajectory Rates of Change\")\n",
1081
+ "print(f\"Average Rate of Change for Valid Chains: {mean_valid:.3f}\")\n",
1082
+ "print(f\"Average Rate of Change for Invalid Chains: {mean_invalid:.3f}\")\n",
1083
+ "print(f\"Correlation between Rate of Change and Validity: {df['rate_of_change'].corr(df['is_valid']):.3f}\")\n",
1084
+ "print(\"\\nT-test for difference in Rate of Change:\")\n",
1085
+ "print(f\"t-statistic: {t_stat:.4f}\")\n",
1086
+ "print(f\"p-value: {p_value:.4f}\")"
1087
+ ],
1088
+ "metadata": {
1089
+ "id": "T7GzkWJzCwJe"
1090
+ },
1091
+ "execution_count": null,
1092
+ "outputs": []
1093
+ },
1094
+ {
1095
+ "cell_type": "code",
1096
+ "source": [
1097
+ "# Statistical tests\n",
1098
+ "df['curvature'] = df['curvature'].fillna(0)\n",
1099
+ "df['rate_of_change'] = df['rate_of_change'].astype(float)\n",
1100
+ "valid_curv = df[df['is_valid']]['curvature']\n",
1101
+ "invalid_curv = df[~df['is_valid']]['curvature']\n",
1102
+ "curv_ttest = ttest_ind(valid_curv, invalid_curv)\n",
1103
+ "\n",
1104
+ "valid_roc = df[df['is_valid']]['rate_of_change']\n",
1105
+ "invalid_roc = df[~df['is_valid']]['rate_of_change']\n",
1106
+ "roc_ttest = ttest_ind(valid_roc, invalid_roc)\n",
1107
+ "\n",
1108
+ "print(\"T-test for trajectory curvature:\", curv_ttest)\n",
1109
+ "print(\"T-test for trajectory rate of change:\", roc_ttest)\n",
1110
+ "\n",
1111
+ "# Correlation with energy\n",
1112
+ "curv_energy_corr = df['curvature'].corr(df['H_energy'])\n",
1113
+ "roc_energy_corr = df['rate_of_change'].corr(df['H_energy'])\n",
1114
+ "\n",
1115
+ "print(\"Correlation between curvature and energy:\", curv_energy_corr)\n",
1116
+ "print(\"Correlation between rate of change and energy:\", roc_energy_corr)"
1117
+ ],
1118
+ "metadata": {
1119
+ "id": "0PabrOYpC7dK"
1120
+ },
1121
+ "execution_count": null,
1122
+ "outputs": []
1123
+ },
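The rate-of-change measure tested above is the mean Euclidean step size along a trajectory. A minimal sketch: for a straight line traversed in unit steps, every step norm is 1, so the mean is exactly 1.

```python
import numpy as np

def calculate_rate_of_change(trajectory):
    # Mean Euclidean distance between consecutive points
    changes = np.diff(trajectory, axis=0)
    return np.mean(np.linalg.norm(changes, axis=1))

# Straight line with unit steps
line = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
roc = calculate_rate_of_change(line)
```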
1124
+ {
1125
+ "cell_type": "code",
1126
+ "source": [
1127
+ "# Frenet-Serret framework\n",
1128
+ "def reduce_dimensionality(trajectories, n_components=3):\n",
1129
+ " \"\"\"Reduce dimensionality of trajectories using PCA\"\"\"\n",
1130
+ " flattened = np.vstack(trajectories)\n",
1131
+ " pca = PCA(n_components=n_components)\n",
1132
+ " reduced = pca.fit_transform(flattened)\n",
1133
+ " return reduced.reshape(len(trajectories), -1, n_components), pca\n",
1134
+ "\n",
1135
+ "def frenet_serret_frame(trajectory):\n",
1136
+ " \"\"\"Compute Frenet-Serret frame for a trajectory\"\"\"\n",
1137
+ " # Compute tangent vectors\n",
1138
+ " T = np.diff(trajectory, axis=0)\n",
1139
+ " T_norm = np.linalg.norm(T, axis=1, keepdims=True)\n",
1140
+ " T = np.divide(T, T_norm, where=T_norm!=0)\n",
1141
+ "\n",
1142
+ " # Compute normal vectors\n",
1143
+ " N = np.diff(T, axis=0)\n",
1144
+ " N_norm = np.linalg.norm(N, axis=1, keepdims=True)\n",
1145
+ " N = np.divide(N, N_norm, where=N_norm!=0)\n",
1146
+ "\n",
1147
+ " # Compute binormal vectors\n",
1148
+ " B = np.cross(T[:-1], N)\n",
1149
+ "\n",
1150
+ " return T[:-1], N, B\n",
1151
+ "\n",
1152
+ "def compute_curvature_torsion(T, N, B):\n",
1153
+ " \"\"\"Compute curvature and torsion from Frenet-Serret frame\"\"\"\n",
1154
+ " dT = np.diff(T, axis=0)\n",
1155
+ " curvature = np.linalg.norm(dT, axis=1)\n",
1156
+ "\n",
1157
+ " # Compute torsion: the Frenet-Serret relation dB/ds = -tau * N gives tau = -(dB . N)\n",
+ " dB = np.diff(B, axis=0)\n",
+ " torsion = -np.sum(dB * N[1:], axis=1)\n",
1160
+ "\n",
1161
+ " return np.mean(curvature), np.mean(torsion)\n",
1162
+ "\n",
1163
+ "# Reduce dimensionality of trajectories\n",
1164
+ "reduced_trajectories, pca = reduce_dimensionality(trajectories)\n",
1165
+ "\n",
1166
+ "# Compute Frenet-Serret frames and curvature/torsion\n",
1167
+ "curvatures = []\n",
1168
+ "torsions = []\n",
1169
+ "for i, traj in enumerate(reduced_trajectories):\n",
1170
+ " try:\n",
1171
+ " T, N, B = frenet_serret_frame(traj)\n",
1172
+ " curvature, torsion = compute_curvature_torsion(T, N, B)\n",
1173
+ " curvatures.append(curvature)\n",
1174
+ " torsions.append(torsion)\n",
1175
+ " except Exception as e:\n",
1176
+ " print(f\"Error processing trajectory {i}: {str(e)}\")\n",
1177
+ " print(f\"Trajectory shape: {traj.shape}\")\n",
1178
+ " curvatures.append(np.nan)\n",
1179
+ " torsions.append(np.nan)\n",
1180
+ "\n",
1181
+ "df['curvature'] = curvatures\n",
1182
+ "df['torsion'] = torsions\n",
1183
+ "\n",
1184
+ "# Remove any NaN values\n",
1185
+ "df = df.dropna(subset=['curvature', 'torsion'])\n"
1186
+ ],
1187
+ "metadata": {
1188
+ "id": "hgpHHxRz438n"
1189
+ },
1190
+ "execution_count": null,
1191
+ "outputs": []
1192
+ },
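As a sanity check on discrete Frenet-style curvature estimates like the one above: for a circle of radius R sampled uniformly, the change in the unit chord tangent divided by the chord length recovers 1/R exactly (both |ΔT| and the chord equal 2·sin(Δθ/2) times 1 and R respectively). A self-contained sketch under that uniform-sampling assumption:

```python
import numpy as np

# Circle of radius R=2, sampled uniformly
R = 2.0
theta = np.linspace(0.0, 2 * np.pi, 101)
circle = np.column_stack([R * np.cos(theta), R * np.sin(theta)])

chords = np.diff(circle, axis=0)                 # chord vectors between samples
chord_len = np.linalg.norm(chords, axis=1)
T = chords / chord_len[:, None]                  # unit tangent estimates

# Turning of the tangent per unit chord length -> curvature 1/R
kappa = np.linalg.norm(np.diff(T, axis=0), axis=1) / chord_len[1:]
```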
1193
+ {
1194
+ "cell_type": "code",
1195
+ "source": [
1196
+ "# Analyze the principal components\n",
1197
+ "explained_variance_ratio = pca.explained_variance_ratio_\n",
1198
+ "cumulative_variance_ratio = np.cumsum(explained_variance_ratio)\n",
1199
+ "\n",
1200
+ "plt.figure(figsize=(10, 6))\n",
1201
+ "plt.plot(range(1, len(explained_variance_ratio) + 1), cumulative_variance_ratio, 'bo-')\n",
1202
+ "plt.xlabel('Number of Components', fontsize=14)\n",
1203
+ "plt.ylabel('Cumulative Explained Variance Ratio', fontsize=14)\n",
1204
+ "plt.title('Explained Variance Ratio by Principal Components', fontsize=16)\n",
1205
+ "plt.savefig('pca_explained_variance.png', dpi=300, bbox_inches='tight')\n",
1206
+ "plt.show()\n",
1207
+ "\n",
1208
+ "print(f\"Explained variance ratio of first 3 components: {explained_variance_ratio[:3]}\")\n",
1209
+ "print(f\"Cumulative explained variance ratio of first 3 components: {cumulative_variance_ratio[2]:.4f}\")"
1210
+ ],
1211
+ "metadata": {
1212
+ "id": "UHASmPhm5dsa"
1213
+ },
1214
+ "execution_count": null,
1215
+ "outputs": []
1216
+ },
1217
+ {
1218
+ "cell_type": "code",
1219
+ "source": [
1220
+ "# Compute and visualize Hamiltonian along trajectories\n",
1221
+ "\n",
1222
+ "def hamiltonian(q, p, q_goal):\n",
1223
+ " \"\"\"Hamiltonian function\"\"\"\n",
1224
+ " T = 0.5 * np.dot(p, p) # Kinetic energy\n",
1225
+ " V = sophisticated_potential(q, q_goal) # Potential energy\n",
1226
+ " return T + V\n",
1227
+ "\n",
1228
+ "def sophisticated_potential(q, q_goal):\n",
1229
+ " \"\"\"A more sophisticated potential energy function\"\"\"\n",
1230
+ " similarity = np.dot(q, q_goal) / (np.linalg.norm(q) * np.linalg.norm(q_goal))\n",
1231
+ " complexity = np.linalg.norm(q) # Assume more complex states have higher norm\n",
1232
+ " return -similarity + 0.1 * complexity # Balance between relevance and complexity\n",
1233
+ "\n",
1234
+ "# Compute and visualize Hamiltonian along trajectories\n",
1235
+ "hamiltonians = []\n",
1236
+ "q_goal = np.mean([traj[-1] for traj in trajectories], axis=0) # Assuming the goal is the average final state\n",
1237
+ "\n",
1238
+ "for traj in trajectories:\n",
1239
+ " H = []\n",
1240
+ " for i in range(len(traj)):\n",
1241
+ " q = traj[i]\n",
1242
+ " p = traj[i] - traj[i-1] if i > 0 else np.zeros_like(q) # Estimate momentum as the difference between states\n",
1243
+ " H.append(hamiltonian(q, p, q_goal))\n",
1244
+ " hamiltonians.append(H)\n",
1245
+ "\n",
1246
+ "plt.figure(figsize=(12, 6))\n",
1247
+ "for i, H in enumerate(hamiltonians[:20]): # Plot first 20 for clarity\n",
1248
+ " plt.plot(H, label=f'Trajectory {i+1}')\n",
1249
+ "plt.title('Hamiltonian Evolution Along Reasoning Trajectories', fontsize=16)\n",
1250
+ "plt.xlabel('Time Step', fontsize=16)\n",
1251
+ "plt.ylabel('Hamiltonian',fontsize=16)\n",
1252
+ "plt.legend()\n",
1253
+ "plt.savefig('hamiltonian_evolution_plot.png', dpi=300, bbox_inches='tight')\n",
1254
+ "plt.show()\n",
1255
+ "\n",
1256
+ "# Statistical analysis\n",
1257
+ "valid_curvature = df[df['is_valid']]['curvature']\n",
1258
+ "invalid_curvature = df[~df['is_valid']]['curvature']\n",
1259
+ "t_stat, p_value = stats.ttest_ind(valid_curvature, invalid_curvature)\n",
1260
+ "\n",
1261
+ "print(f\"T-test for curvature: t-statistic = {t_stat}, p-value = {p_value}\")\n",
1262
+ "\n",
1263
+ "# Correlation analysis\n",
1264
+ "correlation = df['curvature'].corr(df['torsion'])\n",
1265
+ "print(f\"Correlation between curvature and torsion: {correlation}\")\n",
1266
+ "\n"
1267
+ ],
1268
+ "metadata": {
1269
+ "id": "v0V1WiVN6F6g"
1270
+ },
1271
+ "execution_count": null,
1272
+ "outputs": []
1273
+ },
1274
+ {
1275
+ "cell_type": "code",
1276
+ "source": [
1277
+ "# 3D plot of trajectories\n",
1278
+ "fig = plt.figure(figsize=(12,12))\n",
1279
+ "ax = fig.add_subplot(111, projection='3d')\n",
1280
+ "\n",
1281
+ "for i, traj in enumerate(trajectories_3d[:20]): # Plot first 20 for clarity\n",
1282
+ " color = 'green' if df.iloc[i]['is_valid'] else 'red'\n",
1283
+ " ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=color, alpha=0.6)\n",
1284
+ "\n",
1285
+ "ax.set_xlabel('PCA 1', fontsize=14)\n",
1286
+ "ax.set_ylabel('PCA 2', fontsize=14)\n",
1287
+ "ax.set_zlabel('PCA 3', fontsize=14)\n",
1288
+ "ax.set_title('Reasoning Trajectories in PCA Space', fontsize=16)\n",
1289
+ "# Add legend (proxy handles for the two trajectory classes)\n",
+ "valid_handle = plt.Line2D([0], [0], color='green')\n",
+ "invalid_handle = plt.Line2D([0], [0], color='red')\n",
+ "ax.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right')\n",
1291
+ "plt.savefig('pca_trajectories_plot.png', dpi=300, bbox_inches='tight')\n",
1292
+ "plt.show()"
1293
+ ],
1294
+ "metadata": {
1295
+ "id": "7BuXJCesA-2u"
1296
+ },
1297
+ "execution_count": null,
1298
+ "outputs": []
1299
+ },
1300
+ {
1301
+ "cell_type": "code",
1302
+ "source": [
1303
+ "# Statistical Analysis\n",
1304
+ "\n",
1305
+ "pca_means = np.array([traj.mean(axis=0) for traj in trajectories_3d])\n",
1306
+ "X = pd.DataFrame(pca_means, columns=['PCA1', 'PCA2', 'PCA3'])\n",
1307
+ "y = pd.Series(df['is_valid'].values, name='is_valid')\n",
1308
+ "\n",
1309
+ "# Ensure 'is_valid' is boolean\n",
1310
+ "y = y.astype(bool)\n",
1311
+ "\n",
1312
+ "# Combine X and y into a single DataFrame\n",
1313
+ "data = pd.concat([X, y], axis=1)\n",
1314
+ "\n",
1315
+ "# 1. MANOVA test\n",
1316
+ "manova = MANOVA.from_formula('PCA1 + PCA2 + PCA3 ~ is_valid', data=data)\n",
1317
+ "print(\"MANOVA test results:\")\n",
1318
+ "print(manova.mv_test())\n",
1319
+ "\n",
1320
+ "# 2. T-tests for each PCA dimension\n",
1321
+ "for i in range(3):\n",
1322
+ " t_stat, p_value = stats.ttest_ind(X[f'PCA{i+1}'][y], X[f'PCA{i+1}'][~y])\n",
1323
+ " print(f\"T-test for PCA{i+1}: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")\n",
1324
+ "\n",
1325
+ "# 3. Logistic Regression\n",
1326
+ "log_reg = LogisticRegression()\n",
1327
+ "log_reg.fit(X, y)\n",
1328
+ "y_pred = log_reg.predict(X)\n",
1329
+ "accuracy = accuracy_score(y, y_pred)\n",
1330
+ "print(f\"Logistic Regression Accuracy: {accuracy:.4f}\")\n",
1331
+ "\n",
1332
+ "# 4. Effect sizes (Cohen's d) for each PCA dimension\n",
1333
+ "for i in range(3):\n",
1334
+ " cohens_d = (X[f'PCA{i+1}'][y].mean() - X[f'PCA{i+1}'][~y].mean()) / np.sqrt((X[f'PCA{i+1}'][y].var() + X[f'PCA{i+1}'][~y].var()) / 2)\n",
1335
+ " print(f\"Cohen's d for PCA{i+1}: {cohens_d:.4f}\")\n",
1336
+ "\n",
1337
+ "# 5. Trajectory length comparison\n",
1338
+ "trajectory_lengths = np.array([np.sum(np.sqrt(np.sum(np.diff(traj, axis=0)**2, axis=1))) for traj in trajectories_pca])\n",
1339
+ "t_stat, p_value = stats.ttest_ind(trajectory_lengths[y], trajectory_lengths[~y])\n",
1340
+ "print(f\"T-test for trajectory lengths: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")"
1341
+ ],
1342
+ "metadata": {
1343
+ "id": "rqPocLPzDFiM"
1344
+ },
1345
+ "execution_count": null,
1346
+ "outputs": []
1347
+ },
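The effect-size computation above uses Cohen's d with the equal-n pooled standard deviation, (mean_a - mean_b) / sqrt((var_a + var_b)/2). A standalone sketch (helper name `cohens_d` is illustrative) on values where the answer is known in closed form:

```python
import numpy as np

def cohens_d(a, b):
    """Mean difference scaled by the pooled standard deviation (equal-n form)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    pooled_sd = np.sqrt((a.var(ddof=1) + b.var(ddof=1)) / 2)
    return (a.mean() - b.mean()) / pooled_sd

# Means 3 and 5, each sample variance 2.5 -> d = -2 / sqrt(2.5)
d = cohens_d([1, 2, 3, 4, 5], [3, 4, 5, 6, 7])
```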
1348
+ {
1349
+ "cell_type": "code",
1350
+ "source": [
1351
+ "# Correlation between trajectory complexity and validity\n",
1352
+ "# Analyze trajectory complexity\n",
1353
+ "def trajectory_complexity(traj):\n",
1354
+ " return np.sum(np.linalg.norm(np.diff(traj, axis=0), axis=1))\n",
1355
+ "\n",
1356
+ "complexities = [trajectory_complexity(traj) for traj in reduced_trajectories]\n",
1357
+ "df['complexity'] = complexities\n",
1358
+ "complexity_correlation = stats.pointbiserialr(df['is_valid'], df['complexity'])\n",
1359
+ "print(f\"Correlation between trajectory complexity and validity: r = {complexity_correlation.correlation:.4f}, p = {complexity_correlation.pvalue:.4f}\")"
1360
+ ],
1361
+ "metadata": {
1362
+ "id": "csICTST5BcS5"
1363
+ },
1364
+ "execution_count": null,
1365
+ "outputs": []
1366
+ },
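`stats.pointbiserialr`, used above, is the Pearson correlation between a binary variable and a continuous one. A minimal sketch with a hypothetical validity flag and a complexity score that cleanly separates the two groups, so the correlation is strong and significant:

```python
import numpy as np
from scipy import stats

# Hypothetical data: binary validity flag vs. continuous complexity score
is_valid = np.array([0, 0, 0, 1, 1, 1])
complexity = np.array([1.0, 2.0, 3.0, 7.0, 8.0, 9.0])

res = stats.pointbiserialr(is_valid, complexity)
```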
1367
+ {
1368
+ "cell_type": "markdown",
1369
+ "source": [
1370
+ "## Canonical transformations"
1371
+ ],
1372
+ "metadata": {
1373
+ "id": "c0kKU3xdVpMf"
1374
+ }
1375
+ },
1376
+ {
1377
+ "cell_type": "code",
1378
+ "source": [
1379
+ "def hamiltonian(state, t, k):\n",
1380
+ " \"\"\"Simple harmonic oscillator Hamiltonian\"\"\"\n",
1381
+ " q, p = state\n",
1382
+ " return p**2 / 2 + k * q**2 / 2\n",
1383
+ "\n",
1384
+ "def hamilton_equations(state, t, k):\n",
1385
+ " \"\"\"Hamilton's equations for simple harmonic oscillator\"\"\"\n",
1386
+ " q, p = state\n",
1387
+ " dqdt = p\n",
1388
+ " dpdt = -k * q\n",
1389
+ " return [dqdt, dpdt]\n",
1390
+ "\n",
1391
+ "def canonical_transform_to_action_angle(q, p, k):\n",
1392
+ " \"\"\"Transform from (q,p) to action-angle variables (I, theta)\"\"\"\n",
1393
+ " I = (p**2 + k * q**2) / (2 * k)\n",
1394
+ " theta = np.arctan2(np.sqrt(k) * q, p)\n",
1395
+ " return I, theta\n",
1396
+ "\n",
1397
+ "def inverse_canonical_transform(I, theta, k):\n",
1398
+ " \"\"\"Transform from action-angle variables (I, theta) back to (q,p)\"\"\"\n",
1399
+ " q = np.sqrt(2 * I / k) * np.sin(theta)\n",
1400
+ " p = np.sqrt(2 * I * k) * np.cos(theta)\n",
1401
+ " return q, p\n",
1402
+ "\n",
1403
+ "# Parameters\n",
1404
+ "k = 1.0 # Spring constant\n",
1405
+ "t = np.linspace(0, 10, 100)\n",
1406
+ "\n",
1407
+ "# Apply canonical transformation to our trajectories\n",
1408
+ "action_angle_trajectories = []\n",
1409
+ "for traj in trajectories_pca:\n",
1410
+ " q, p = traj[:, 0], traj[:, 1] # Assuming first two PCs represent position and momentum\n",
1411
+ " I, theta = canonical_transform_to_action_angle(q, p, k)\n",
1412
+ " action_angle_trajectories.append(np.column_stack((I, theta)))\n",
1413
+ "\n",
1414
+ "\n",
1415
+ "# Analysis\n",
1416
+ "action_means_valid = [np.mean(traj[:, 0]) for traj, valid in zip(action_angle_trajectories, df['is_valid'].tolist()) if valid]\n",
1417
+ "action_means_nonvalid = [np.mean(traj[:, 0]) for traj, valid in zip(action_angle_trajectories, df['is_valid'].tolist()) if not valid]\n",
1418
+ "angle_ranges_valid = [np.ptp(traj[:, 1]) for traj, valid in zip(action_angle_trajectories, df['is_valid'].tolist()) if valid]\n",
1419
+ "angle_ranges_nonvalid = [np.ptp(traj[:, 1]) for traj, valid in zip(action_angle_trajectories, df['is_valid'].tolist()) if not valid]\n",
1420
+ "\n",
1421
+ "print(f\"Mean action for valid chains: {np.mean(action_means_valid):.4f}\")\n",
1422
+ "print(f\"Mean action for non-valid chains: {np.mean(action_means_nonvalid):.4f}\")\n",
1423
+ "print(f\"Mean angle range for valid chains: {np.mean(angle_ranges_valid):.4f}\")\n",
1424
+ "print(f\"Mean angle range for non-valid chains: {np.mean(angle_ranges_nonvalid):.4f}\")\n",
1425
+ "\n",
1426
+ "# Statistical tests\n",
1427
+ "from scipy import stats\n",
1428
+ "\n",
1429
+ "t_stat, p_value = stats.ttest_ind(action_means_valid, action_means_nonvalid)\n",
1430
+ "print(f\"T-test for action means: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")\n",
1431
+ "\n",
1432
+ "t_stat, p_value = stats.ttest_ind(angle_ranges_valid, angle_ranges_nonvalid)\n",
1433
+ "print(f\"T-test for angle ranges: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")\n",
1434
+ "\n",
1435
+ "# Classify trajectories based on action and angle properties\n",
1436
+ "def classify_trajectory(action, angle_range, valid):\n",
1437
+ " high_action = np.mean(action_means_valid if valid else action_means_nonvalid) + np.std(action_means_valid if valid else action_means_nonvalid)\n",
1438
+ " low_action = np.mean(action_means_valid if valid else action_means_nonvalid) - np.std(action_means_valid if valid else action_means_nonvalid)\n",
1439
+ " high_angle_range = np.mean(angle_ranges_valid if valid else angle_ranges_nonvalid) + np.std(angle_ranges_valid if valid else angle_ranges_nonvalid)\n",
1440
+ "\n",
1441
+ " if action > high_action and angle_range > high_angle_range:\n",
1442
+ " return \"High energy, complex reasoning\"\n",
1443
+ " elif action < low_action and angle_range > high_angle_range:\n",
1444
+ " return \"Low energy, exploratory reasoning\"\n",
1445
+ " elif action > high_action and angle_range <= high_angle_range:\n",
1446
+ " return \"High energy, focused reasoning\"\n",
1447
+ " elif action < low_action and angle_range <= high_angle_range:\n",
1448
+ " return \"Low energy, simple reasoning\"\n",
1449
+ " else:\n",
1450
+ " return \"Moderate reasoning\""
1451
+ ],
1452
+ "metadata": {
1453
+ "id": "Pm52IjYTXMMH"
1454
+ },
1455
+ "execution_count": null,
1456
+ "outputs": []
1457
+ },
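With the notebook's choice k = 1.0, the forward and inverse transforms above are exact inverses, so a round trip (q, p) → (I, θ) → (q, p) recovers the original phase-space point. A self-contained sketch pinning k to 1 as in the cell (for general k the action normalization would instead use ω = √k):

```python
import numpy as np

def to_action_angle(q, p, k):
    I = (p ** 2 + k * q ** 2) / (2 * k)
    theta = np.arctan2(np.sqrt(k) * q, p)
    return I, theta

def from_action_angle(I, theta, k):
    q = np.sqrt(2 * I / k) * np.sin(theta)
    p = np.sqrt(2 * I * k) * np.cos(theta)
    return q, p

k = 1.0  # round trip is exact for k = 1, the value used in the notebook
q0 = np.array([0.5, -1.2, 2.0])
p0 = np.array([1.0, 0.3, -0.7])

I, theta = to_action_angle(q0, p0, k)
q1, p1 = from_action_angle(I, theta, k)
```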
1458
+ {
+ "cell_type": "code",
+ "source": [
+ "# Plotting\n",
+ "fig = plt.figure(figsize=(15, 5))\n",
+ "\n",
+ "# Original space\n",
+ "ax1 = fig.add_subplot(131)\n",
+ "for traj, valid in zip(trajectories_pca[:10], df['is_valid'].tolist()[:10]):  # Plot first 10 for clarity\n",
+ "    color = 'green' if valid else 'red'\n",
+ "    ax1.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7)\n",
+ "ax1.set_xlabel('PC1 (q)', fontsize=12)\n",
+ "ax1.set_ylabel('PC2 (p)', fontsize=12)\n",
+ "ax1.set_title('Original Phase Space', fontsize=14)\n",
+ "ax1.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right', fontsize=12)\n",
+ "\n",
+ "# Action-Angle space\n",
+ "ax2 = fig.add_subplot(132)\n",
+ "for traj, valid in zip(action_angle_trajectories[:10], df['is_valid'].tolist()[:10]):\n",
+ "    color = 'green' if valid else 'red'\n",
+ "    ax2.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7)\n",
+ "ax2.set_xlabel('Action (I)', fontsize=12)\n",
+ "ax2.set_ylabel('Angle (theta)', fontsize=12)\n",
+ "ax2.set_title('Action-Angle Space', fontsize=14)\n",
+ "ax2.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right', fontsize=12)\n",
+ "\n",
+ "# 3D visualization\n",
+ "ax3 = fig.add_subplot(133, projection='3d')\n",
+ "for traj, valid in zip(action_angle_trajectories[:10], df['is_valid'].tolist()[:10]):\n",
+ "    color = 'green' if valid else 'red'\n",
+ "    ax3.plot(traj[:, 0], np.cos(traj[:, 1]), np.sin(traj[:, 1]), color=color, alpha=0.7)\n",
+ "ax3.set_xlabel('Action (I)', fontsize=12)\n",
+ "ax3.set_ylabel('cos(theta)', fontsize=12)\n",
+ "ax3.set_zlabel('sin(theta)', fontsize=12)\n",
+ "ax3.set_title('3D Action-Angle Space', fontsize=14)\n",
+ "ax3.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right', fontsize=12)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('canonical_transformation_analysis_with_validity.png', dpi=300, bbox_inches='tight')\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "YlzvprO0ZBo1"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Conservation laws"
+ ],
+ "metadata": {
+ "id": "b-FE7nQWW1Oe"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def calculate_hamiltonian(q, p):\n",
+ "    \"\"\"Simple Hamiltonian function\"\"\"\n",
+ "    return 0.5 * (q**2 + p**2)\n",
+ "\n",
+ "def calculate_angular_momentum(q, p):\n",
+ "    \"\"\"Angular momentum-like quantity\"\"\"\n",
+ "    return q * p\n",
+ "\n",
+ "def calculate_energy_like_quantity(q, p):\n",
+ "    \"\"\"Energy-like conserved quantity\"\"\"\n",
+ "    return q**2 - p**2\n",
+ "\n",
+ "def analyze_conservation(trajectories, quantity_func, quantity_name):\n",
+ "    conserved_scores = []\n",
+ "    for traj in trajectories:\n",
+ "        # Compare the quantity at the trajectory's endpoints\n",
+ "        q_start, q_end = traj[0, 0], traj[-1, 0]\n",
+ "        p_start, p_end = traj[0, 1], traj[-1, 1]\n",
+ "        quantity_start = quantity_func(q_start, p_start)\n",
+ "        quantity_end = quantity_func(q_end, p_end)\n",
+ "        change = abs(quantity_end - quantity_start)\n",
+ "        conserved_scores.append(change)\n",
+ "    return conserved_scores\n",
+ "\n",
+ "# Analyze conservation for different quantities\n",
+ "hamiltonian_scores = analyze_conservation(trajectories_2d, calculate_hamiltonian, \"Hamiltonian\")\n",
+ "angular_momentum_scores = analyze_conservation(trajectories_2d, calculate_angular_momentum, \"Angular Momentum\")\n",
+ "energy_scores = analyze_conservation(trajectories_2d, calculate_energy_like_quantity, \"Energy-like Quantity\")\n",
+ "\n",
+ "# Print some statistics\n",
+ "print(\"Hamiltonian changes - Mean: {:.4f}, Std: {:.4f}\".format(np.mean(hamiltonian_scores), np.std(hamiltonian_scores)))\n",
+ "print(\"Angular Momentum changes - Mean: {:.4f}, Std: {:.4f}\".format(np.mean(angular_momentum_scores), np.std(angular_momentum_scores)))\n",
+ "print(\"Energy-like Quantity changes - Mean: {:.4f}, Std: {:.4f}\".format(np.mean(energy_scores), np.std(energy_scores)))"
+ ],
+ "metadata": {
+ "id": "t_aym0wlWBpg"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Visualize conservation of quantities\n",
+ "plt.figure(figsize=(15, 5))\n",
+ "\n",
+ "plt.subplot(131)\n",
+ "plt.hist(hamiltonian_scores, bins=20, color='blue', alpha=0.7)\n",
+ "plt.title(\"Conservation of Hamiltonian\", fontsize=16)\n",
+ "plt.xlabel(\"Absolute Change\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "\n",
+ "plt.subplot(132)\n",
+ "plt.hist(angular_momentum_scores, bins=20, color='green', alpha=0.7)\n",
+ "plt.title(\"Conservation of Angular Momentum\", fontsize=16)\n",
+ "plt.xlabel(\"Absolute Change\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "\n",
+ "plt.subplot(133)\n",
+ "plt.hist(energy_scores, bins=20, color='red', alpha=0.7)\n",
+ "plt.title(\"Conservation of Energy-like Quantity\", fontsize=16)\n",
+ "plt.xlabel(\"Absolute Change\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('conservation_laws_analysis.png', dpi=300, bbox_inches='tight')\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "zOFQfeap55P7"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Calculate the overall range for x-axis\n",
+ "all_scores = np.concatenate([hamiltonian_scores, angular_momentum_scores, energy_scores])\n",
+ "min_score = np.min(all_scores)\n",
+ "max_score = np.max(all_scores)\n",
+ "\n",
+ "# Create bins that cover the entire range\n",
+ "bins = np.linspace(min_score, max_score, 21)  # 20 bins\n",
+ "\n",
+ "# Compute histograms\n",
+ "h_hist, _ = np.histogram(hamiltonian_scores, bins=bins)\n",
+ "a_hist, _ = np.histogram(angular_momentum_scores, bins=bins)\n",
+ "e_hist, _ = np.histogram(energy_scores, bins=bins)\n",
+ "\n",
+ "# Find the maximum frequency across all histograms\n",
+ "max_freq = max(np.max(h_hist), np.max(a_hist), np.max(e_hist))\n",
+ "\n",
+ "plt.figure(figsize=(15, 5))\n",
+ "\n",
+ "plt.subplot(131)\n",
+ "plt.hist(hamiltonian_scores, bins=bins, color='blue', alpha=0.7)\n",
+ "plt.title(\"Conservation of Hamiltonian\", fontsize=16)\n",
+ "plt.xlabel(\"Absolute Change\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "plt.xlim(min_score, max_score)\n",
+ "plt.ylim(0, max_freq)\n",
+ "\n",
+ "plt.subplot(132)\n",
+ "plt.hist(angular_momentum_scores, bins=bins, color='green', alpha=0.7)\n",
+ "plt.title(\"Conservation of Angular Momentum\", fontsize=16)\n",
+ "plt.xlabel(\"Absolute Change\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "plt.xlim(min_score, max_score)\n",
+ "plt.ylim(0, max_freq)\n",
+ "\n",
+ "plt.subplot(133)\n",
+ "plt.hist(energy_scores, bins=bins, color='red', alpha=0.7)\n",
+ "plt.title(\"Conservation of Energy-like Quantity\", fontsize=16)\n",
+ "plt.xlabel(\"Absolute Change\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "plt.xlim(min_score, max_score)\n",
+ "plt.ylim(0, max_freq)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('conservation_laws_analysis_same_scales.png', dpi=300, bbox_inches='tight')\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "9FYy8-nIZwsy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def calculate_trajectory_entropy(trajectory):\n",
+ "    \"\"\"Calculate the entropy of a trajectory.\"\"\"\n",
+ "    # Discretize the trajectory into bins\n",
+ "    hist, _ = np.histogram(trajectory, bins=20, density=True)\n",
+ "    return entropy(hist)\n",
+ "\n",
+ "def calculate_free_energy(trajectory, temperature=1.0):\n",
+ "    \"\"\"Calculate a free energy analog for a trajectory.\"\"\"\n",
+ "    # Assume energy is proportional to the squared distance from the origin\n",
+ "    energy = np.sum(trajectory**2, axis=1)\n",
+ "    traj_entropy = calculate_trajectory_entropy(energy)  # avoid shadowing scipy.stats.entropy\n",
+ "    return np.mean(energy) - temperature * traj_entropy\n",
+ "\n",
+ "# Apply to all trajectories\n",
+ "trajectory_entropies = [calculate_trajectory_entropy(traj) for traj in trajectories_2d]\n",
+ "free_energies = [calculate_free_energy(traj) for traj in trajectories_2d]\n",
+ "\n",
+ "# Analyze the results\n",
+ "print(\"Mean trajectory entropy:\", np.mean(trajectory_entropies))\n",
+ "print(\"Mean free energy:\", np.mean(free_energies))\n",
+ "\n",
+ "# Visualize the results\n",
+ "plt.figure(figsize=(12, 5))\n",
+ "plt.subplot(121)\n",
+ "plt.hist(trajectory_entropies, bins=20)\n",
+ "plt.title(\"Distribution of Trajectory Entropies\", fontsize=16)\n",
+ "plt.xlabel(\"Entropy\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "\n",
+ "plt.subplot(122)\n",
+ "plt.hist(free_energies, bins=20)\n",
+ "plt.title(\"Distribution of Free Energies\", fontsize=16)\n",
+ "plt.xlabel(\"Free Energy\", fontsize=14)\n",
+ "plt.ylabel(\"Frequency\", fontsize=14)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "Ws8Ugh7kbj9T"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def measure_computation_time(trajectories, num_samples):\n",
+ "    \"\"\"Measure computation time for different numbers of trajectories.\"\"\"\n",
+ "    times = []\n",
+ "    sample_sizes = range(100, num_samples, 100)\n",
+ "\n",
+ "    for size in sample_sizes:\n",
+ "        start_time = time.time()\n",
+ "        _ = [analyze_trajectory(traj) for traj in trajectories[:size]]\n",
+ "        end_time = time.time()\n",
+ "        times.append(end_time - start_time)\n",
+ "\n",
+ "    return sample_sizes, times\n",
+ "\n",
+ "def analyze_trajectory(trajectory):\n",
+ "    \"\"\"Placeholder for your trajectory analysis function.\"\"\"\n",
+ "    # Replace this with your actual analysis\n",
+ "    return calculate_hamiltonian(trajectory[:, 0], trajectory[:, 1])\n",
+ "\n",
+ "# Measure computation time\n",
+ "sample_sizes, computation_times = measure_computation_time(trajectories_2d, len(trajectories_2d))\n"
+ ],
+ "metadata": {
+ "id": "c4hO5bUXb_VP"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Plot the results\n",
+ "plt.figure(figsize=(10, 6))\n",
+ "plt.plot(sample_sizes, computation_times, 'b-')\n",
+ "plt.title(\"Computational Complexity\", fontsize=16)\n",
+ "plt.xlabel(\"Number of Trajectories\", fontsize=14)\n",
+ "plt.ylabel(\"Computation Time (seconds)\", fontsize=14)\n",
+ "plt.grid(True)\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "OWw-V4apZX48"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Estimate complexity\n",
+ "def complexity_function(x, a, b):\n",
+ "    return a * x**b\n",
+ "\n",
+ "popt, _ = curve_fit(complexity_function, sample_sizes, computation_times)\n",
+ "\n",
+ "print(f\"Estimated complexity: O(n^{popt[1]:.2f})\")"
+ ],
+ "metadata": {
+ "id": "Pady9Cj8ZIdz"
+ },
+ "execution_count": null,
+ "outputs": []
+ {
1757
+ "cell_type": "code",
1758
+ "source": [
1759
+ "def classify_trajectory(trajectory):\n",
1760
+ " \"\"\"Classify a trajectory as valid or invalid based on Hamiltonian conservation.\"\"\"\n",
1761
+ " hamiltonian_change = np.abs(calculate_hamiltonian(trajectory[0, 0], trajectory[0, 1]) -\n",
1762
+ " calculate_hamiltonian(trajectory[-1, 0], trajectory[-1, 1]))\n",
1763
+ " return hamiltonian_change < 0.5 # Threshold for classification\n",
1764
+ "\n",
1765
+ "# Split the data\n",
1766
+ "X_train, X_test, y_train, y_test = train_test_split(trajectories_2d, df['is_valid'], test_size=0.2, random_state=42)\n",
1767
+ "\n",
1768
+ "# Classify test set\n",
1769
+ "y_pred = [classify_trajectory(traj) for traj in X_test]\n",
1770
+ "\n",
1771
+ "# Analyze errors\n",
1772
+ "conf_matrix = confusion_matrix(y_test, y_pred)\n",
1773
+ "class_report = classification_report(y_test, y_pred)\n",
1774
+ "\n",
1775
+ "print(\"Confusion Matrix:\")\n",
1776
+ "print(conf_matrix)\n",
1777
+ "print(\"\\nClassification Report:\")\n",
1778
+ "print(class_report)\n",
1779
+ "\n",
1780
+ "# Analyze misclassified trajectories\n",
1781
+ "misclassified = X_test[y_test != y_pred]\n",
1782
+ "misclassified_labels = y_test[y_test != y_pred]\n",
1783
+ "\n",
1784
+ "print(\"\\nAnalysis of Misclassified Trajectories:\")\n",
1785
+ "for i, (traj, true_label) in enumerate(zip(misclassified, misclassified_labels)):\n",
1786
+ " hamiltonian_change = np.abs(calculate_hamiltonian(traj[0, 0], traj[0, 1]) -\n",
1787
+ " calculate_hamiltonian(traj[-1, 0], traj[-1, 1]))\n",
1788
+ " print(f\"Trajectory {i}:\")\n",
1789
+ " print(f\" True label: {'Valid' if true_label else 'Invalid'}\")\n",
1790
+ " print(f\" Predicted: {'Valid' if classify_trajectory(traj) else 'Invalid'}\")\n",
1791
+ " print(f\" Hamiltonian change: {hamiltonian_change:.4f}\")\n",
1792
+ " print(f\" Start point: {traj[0]}\")\n",
1793
+ " print(f\" End point: {traj[-1]}\")\n",
1794
+ " print()\n",
1795
+ "\n",
1796
+ "# Visualize some misclassified trajectories\n",
1797
+ "plt.figure(figsize=(15, 5))\n",
1798
+ "for i in range(3):\n",
1799
+ " plt.subplot(1, 3, i+1)\n",
1800
+ " plt.plot(misclassified[i][:, 0], misclassified[i][:, 1], 'r-')\n",
1801
+ " plt.scatter(misclassified[i][0, 0], misclassified[i][0, 1], c='g', label='Start')\n",
1802
+ " plt.scatter(misclassified[i][-1, 0], misclassified[i][-1, 1], c='b', label='End')\n",
1803
+ " plt.title(f\"Misclassified Trajectory {i+1}\", fontsize=16)\n",
1804
+ " plt.xlabel(\"PC1\", fontsize=14)\n",
1805
+ " plt.ylabel(\"PC2\", fontsize=14)\n",
1806
+ " plt.legend()\n",
1807
+ "plt.tight_layout()\n",
1808
+ "plt.show()"
1809
+ ],
1810
+ "metadata": {
1811
+ "id": "p9PhYaNpcJJd"
1812
+ },
1813
+ "execution_count": null,
1814
+ "outputs": []
1815
+ }
1816
+ ]
1817
+ }