{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "source": [ "\"\"\"\n", "This script implements the experiments conducted for\n", "the paper \"Optimizing AI Reasoning: A Hamiltonian Dynamics Approach to\n", "Multi-Hop Question Answering\".\n", "\n", "Author: Javier Marín\n", "Email: javier@jmarin.info\n", "Version: 1.0.0\n", "Date: October 2024\n", "\n", "License: MIT License\n", "\n", "Copyright (c) 2024 Javier Marín\n", "\n", "Permission is hereby granted, free of charge, to any person obtaining a copy\n", "of this software and associated documentation files (the \"Software\"), to deal\n", "in the Software without restriction, including without limitation the rights\n", "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", "copies of the Software, and to permit persons to whom the Software is\n", "furnished to do so, subject to the following conditions:\n", "\n", "The above copyright notice and this permission notice shall be included in all\n", "copies or substantial portions of the Software.\n", "\n", "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", "SOFTWARE.\n", "\n", "Dependencies:\n", "- Python 3.8+\n", "- NumPy\n", "- Pandas\n", "- PyTorch\n", "- Transformers\n", "- Scikit-learn\n", "- SciPy\n", "- Statsmodels\n", "- Matplotlib\n", "- Seaborn\n", "\n", "For a full list of dependencies and their versions, see requirements.txt\n", "\"\"\"" ], "metadata": { "id": "T-57ivc-aTrA" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Imports" ], "metadata": { "id": "QUcpzyBmWpLv" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l2rfFoVtIL6_" }, "outputs": [], "source": [ "# Standard library imports\n", "import os\n", "import re\n", "import time\n", "\n", "# Third-party imports\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "\n", "from transformers import AutoTokenizer, AutoModel\n", "from statsmodels.multivariate.manova import MANOVA\n", "from scipy import stats\n", "from scipy.optimize import curve_fit\n", "from scipy.integrate import odeint\n", "from sklearn import (\n", " metrics,\n", " model_selection,\n", " cluster,\n", " decomposition,\n", " feature_extraction,\n", " linear_model\n", ")\n", "\n", "# Names that later cells use unqualified\n", "from scipy.stats import ttest_ind\n", "from sklearn.cluster import KMeans\n", "from sklearn.decomposition import PCA\n", "\n", "# Visualization settings\n", "sns.set_theme(style=\"whitegrid\", context=\"paper\")\n", "plt.rcParams['font.family'] = 'serif'\n", "plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']" ] }, { "cell_type": "markdown", "source": [ "## Load BERT pretrained model" ], "metadata": { "id": "4nApCVrOWkR3" } }, { "cell_type": "code", "source": [ "# Load pre-trained model and tokenizer\n", "tokenizer = 
AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", "model = AutoModel.from_pretrained(\"bert-base-uncased\")" ], "metadata": { "id": "hT2I1H8BIOp_" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Load data" ], "metadata": { "id": "9KKw24bCWgWj" } }, { "cell_type": "code", "source": [ "# Load the OBQA dataset\n", "df = pd.read_csv(\"obqa_chains.csv\", sep=\";\")\n", "\n", "# Ensure necessary columns exist\n", "required_columns = ['QID', 'Chain#', 'Question', 'Answer', 'Fact1', 'Fact2', 'Turk']\n", "missing_columns = [col for col in required_columns if col not in df.columns]\n", "if missing_columns:\n", " raise ValueError(f\"Missing required columns: {missing_columns}\")\n", "\n", "# Preprocess the data\n", "df['Question'] = df['Question'] + \" \" + df['Answer'] # Combine question and answer\n", "df['is_valid'] = df['Turk'].str.contains('yes', case=False, na=False)" ], "metadata": { "id": "g2f-T9koIOjH" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Model embeddings" ], "metadata": { "id": "XdN9XTGOWdsh" } }, { "cell_type": "code", "source": [ "def get_bert_embedding(text):\n", " \"\"\"Get BERT embedding for a given text.\"\"\"\n", " inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True, max_length=512)\n", " with torch.no_grad():\n", " outputs = model(**inputs)\n", " return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()\n", "\n", "def refined_hamiltonian_energy(chain):\n", " emb1 = get_bert_embedding(chain['Fact1'])\n", " emb2 = get_bert_embedding(chain['Fact2'])\n", " emb_q = get_bert_embedding(chain['Question'])\n", "\n", " # Refined kinetic term: measure of change between facts\n", " T = np.linalg.norm(emb2 - emb1)\n", "\n", " # Refined potential term: measure of relevance to question\n", " V = (np.dot(emb1, emb_q) + np.dot(emb2, emb_q)) / 2\n", "\n", " # Total \"Hamiltonian\" energy: balance between change and relevance\n", " H = T - V\n", 
"\n", " return H, T, V\n", "\n", "\n", "# Analyze energy conservation\n", "def energy_conservation_score(chain):\n", " _, T, V = refined_hamiltonian_energy(chain)\n", " # Measure how balanced T and V are\n", " return 1 / (1 + abs(T - V)) # Now always between 0 and 1, 1 being perfect balance\n", "\n", "\n", "\n", "# Calculate refined energies and scores\n", "df['H_energy'], df['T_energy'], df['V_energy'] = zip(*df.apply(refined_hamiltonian_energy, axis=1))\n", "# Same formula as energy_conservation_score, applied to the T/V columns just computed\n", "# so BERT is not run a second time on every row.\n", "df['energy_conservation'] = 1 / (1 + (df['T_energy'] - df['V_energy']).abs())" ], "metadata": { "id": "3q4EMfekIOZ_" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Hamiltonian systems" ], "metadata": { "id": "pvQgqhW2Wage" } }, { "cell_type": "code", "source": [ "def get_trajectory(row):\n", " # Ensure we're working with strings\n", " chain = [str(row['Fact1']), str(row['Fact2'])]\n", " embeddings = [get_bert_embedding(sentence) for sentence in chain]\n", " return np.array(embeddings)\n", "\n", "# NOTE: this redefines (shadows) the question-based refined_hamiltonian_energy of the\n", "# embeddings cell; here the potential V depends only on the two facts.\n", "def refined_hamiltonian_energy(chain):\n", " emb1 = get_bert_embedding(chain['Fact1'])\n", " emb2 = get_bert_embedding(chain['Fact2'])\n", "\n", " # Refined kinetic term: measure of change between facts\n", " T = np.linalg.norm(emb2 - emb1)\n", "\n", " # Refined potential term: measure of relevance to facts\n", " V = (np.linalg.norm(emb1) + np.linalg.norm(emb2)) / 2\n", "\n", " # Total \"Hamiltonian\" energy: balance between change and relevance\n", " H = T - V\n", "\n", " return H, T, V\n", "\n", "\n", "def compute_trajectory_energy(trajectory):\n", " # H = T - V for a (Fact1, Fact2) embedding trajectory built by get_trajectory, using the\n", " # same T and V as refined_hamiltonian_energy above, but computed on the embeddings\n", " # themselves: re-embedding their string repr with BERT would measure the wrong text.\n", " emb1, emb2 = trajectory[0], trajectory[1]\n", " T = np.linalg.norm(emb2 - emb1)\n", " V = (np.linalg.norm(emb1) + np.linalg.norm(emb2)) / 2\n", " return T - V\n", "\n", "\n", "# Compute trajectories for all chains\n", "trajectories = df.apply(get_trajectory, axis=1)\n", "\n", "# Compute energies for trajectories\n", "trajectory_energies = trajectories.apply(compute_trajectory_energy)\n" ], "metadata": { "id": "yveIXutUX3ub" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", 
"source": [ "# Use PCA to reduce dimensionality for visualization\n", "pca = PCA(n_components=3)\n", "all_points = np.vstack(trajectories.values)\n", "pca_result = pca.fit_transform(all_points)\n", "\n", "trajectories_3d = trajectories.apply(lambda t: pca.transform(t))\n", "\n", "\n", "# Analyze trajectory properties\n", "def trajectory_length(traj):\n", " return np.sum(np.sqrt(np.sum(np.diff(traj, axis=0)**2, axis=1)))\n", "\n", "def trajectory_smoothness(traj):\n", " first = abs(np.diff(traj[0], axis=0))[0]\n", " second = abs(np.diff(traj[1], axis=0))[0]\n", " return (first + second)/2\n", "\n", "traj_properties = pd.DataFrame({\n", " 'length': trajectories_3d.apply(trajectory_length),\n", " 'smoothness': trajectories_3d.apply(trajectory_smoothness),\n", " 'is_valid': df['is_valid']\n", "})\n" ], "metadata": { "id": "qFF7_0TD6JRO" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Create the main figure and grid for subplots\n", "fig, axs = plt.subplots(2, 2, figsize=(15, 12))\n", "fig.suptitle(\"Refined Hamiltonian-Inspired Energy Analysis of Reasoning Chains\", fontsize=16)\n", "\n", "# Distribution of Hamiltonian Energy\n", "sns.histplot(data=df, x='H_energy', ax=axs[0, 0], kde=True, color='blue', bins=50)\n", "axs[0, 0].set_title(\"Distribution of Refined Hamiltonian Energy\")\n", "axs[0, 0].set_xlabel(\"Hamiltonian Energy\")\n", "axs[0, 0].set_ylabel(\"Count\")\n", "\n", "# Kinetic vs Potential Energy\n", "scatter = axs[0, 1].scatter(df['T_energy'], df['V_energy'], c=df['H_energy'], cmap='viridis', s=5, alpha=0.6)\n", "axs[0, 1].set_title(\"Refined Kinetic vs Potential Energy\")\n", "axs[0, 1].set_xlabel(\"Kinetic Energy (T)\")\n", "axs[0, 1].set_ylabel(\"Potential Energy (V)\")\n", "plt.colorbar(scatter, ax=axs[0, 1], label=\"Hamiltonian Energy\")\n", "\n", "# Hamiltonian Energy: Valid vs Invalid Chains\n", "valid_chains = df[df['is_valid']]\n", "invalid_chains = df[~df['is_valid']]\n", "sns.histplot(data=valid_chains, 
x='H_energy', ax=axs[1, 0], kde=True, color='green', label='Valid Chains', bins=50, alpha=0.6)\n", "sns.histplot(data=invalid_chains, x='H_energy', ax=axs[1, 0], kde=True, color='red', label='Invalid Chains', bins=50, alpha=0.6)\n", "axs[1, 0].set_title(\"Refined Hamiltonian Energy: Valid vs Invalid Chains\")\n", "axs[1, 0].set_xlabel(\"Hamiltonian Energy\")\n", "axs[1, 0].set_ylabel(\"Count\")\n", "axs[1, 0].legend()\n", "\n", "# Distribution of Energy Conservation Scores\n", "sns.histplot(data=df, x='energy_conservation', ax=axs[1, 1], kde=True, color='orange', bins=50)\n", "axs[1, 1].set_title(\"Distribution of Refined Energy Conservation Scores\")\n", "axs[1, 1].set_xlabel(\"Energy Conservation Score\")\n", "axs[1, 1].set_ylabel(\"Count\")\n", "\n", "# Adjust layout and display\n", "plt.tight_layout()\n", "plt.subplots_adjust(top=0.93) # Adjust for main title\n", "plt.savefig('refined_hamiltonian_analysis.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ], "metadata": { "id": "kqfbA7w3NuPM" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Calculate direction vectors\n", "def calculate_direction(trajectory):\n", " return trajectory[1] - trajectory[0]\n", "\n", "direction_vectors = np.array([calculate_direction(traj) for traj in trajectories_3d])\n", "\n", "# Calculate magnitude and angle of direction vectors\n", "magnitudes = np.linalg.norm(direction_vectors, axis=1)\n", "angles = np.arctan2(direction_vectors[:, 1], direction_vectors[:, 0])\n", "\n", "# Add these to the dataframe\n", "df['trajectory_magnitude'] = magnitudes\n", "df['trajectory_angle'] = angles\n", "\n", "# Visualize magnitude distribution\n", "plt.figure(figsize=(12, 6))\n", "sns.histplot(data=df, x='trajectory_magnitude', hue='is_valid', element='step', stat='density', common_norm=False)\n", "plt.title('Distribution of Trajectory Magnitudes')\n", "plt.xlabel('Magnitude')\n", "plt.ylabel('Density')\n", "plt.legend(title='Is Valid')\n", 
"plt.tight_layout()\n", "plt.savefig('trajectories_magntude_plot.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ], "metadata": { "id": "tYVhJJbPwNxo" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "plt.figure(figsize=(12, 6))\n", "\n", "# Define colors explicitly\n", "colors = {'Valid': 'blue', 'Invalid': 'red'}\n", "\n", "# Create a new DataFrame with the data for plotting\n", "plot_data = pd.DataFrame({\n", " 'Hamiltonian Energy': df['H_energy'],\n", " 'Validity': df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n", "})\n", "\n", "# Create the histogram plot with explicit colors\n", "sns.histplot(data=plot_data, x='Hamiltonian Energy', hue='Validity',\n", " element='step', stat='density', common_norm=False,\n", " palette=colors)\n", "\n", "plt.title('Distribution of Refined Hamiltonian Energy', fontsize=16)\n", "plt.xlabel('Hamiltonian Energy', fontsize=14)\n", "plt.ylabel('Density', fontsize=14)\n", "\n", "# Adjust legend: seaborn's hue legend is not built from labelled artists, so a bare\n", "# plt.legend() finds nothing to show; pass explicit handles instead.\n", "handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in colors.values()]\n", "plt.legend(handles, list(colors.keys()), title='Chain Validity', title_fontsize='13', fontsize='12')\n", "\n", "# Mean energy per validity class, computed from the data so the markers track df\n", "mean_h_valid = df.loc[df['is_valid'], 'H_energy'].mean()\n", "mean_h_invalid = df.loc[~df['is_valid'], 'H_energy'].mean()\n", "\n", "# Add vertical lines for mean energies\n", "plt.axvline(x=mean_h_valid, color='blue', linestyle='--', label='Mean Valid')\n", "plt.axvline(x=mean_h_invalid, color='red', linestyle='--', label='Mean Invalid')\n", "\n", "# Add text annotations for mean energies\n", "plt.text(mean_h_valid, plt.gca().get_ylim()[1], 'Mean Valid',\n", " rotation=90, va='top', ha='right', color='blue')\n", "plt.text(mean_h_invalid, plt.gca().get_ylim()[1], 'Mean Invalid',\n", " rotation=90, va='top', ha='left', color='red')\n", "\n", "plt.tight_layout()\n", "plt.savefig('refined_hamiltonian_energy_distribution.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ], "metadata": { "id": "m1fHZ-NpMnHD" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Perform PCA to reduce to 2 dimensions\n", "pca = PCA(n_components=2)\n", "trajectories_2d = pca.fit_transform(np.vstack(trajectories))\n", "\n", "# Reshape the 
data back into trajectories\n", "trajectories_2d = trajectories_2d.reshape(len(trajectories), -1, 2)\n", "\n", "# Create the plot\n", "# Apply the style before creating the figure. Matplotlib 3.8 removed the bundled\n", "# 'seaborn-whitegrid' style (renamed 'seaborn-v0_8-whitegrid' in 3.6), so use seaborn's own.\n", "sns.set_style(\"whitegrid\")\n", "sns.set_context(\"paper\")\n", "plt.rcParams['font.family'] = 'serif'\n", "plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']\n", "plt.figure(figsize=(12, 10))\n", "\n", "# Plot trajectories\n", "valid_trajectories = []\n", "invalid_trajectories = []\n", "for i, traj in enumerate(trajectories_2d[:100]): # Limit to 100 for clarity\n", " if df.iloc[i]['is_valid']:\n", " valid_trajectories.append(traj)\n", " color = 'green'\n", " else:\n", " invalid_trajectories.append(traj)\n", " color = 'red'\n", " plt.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.5)\n", " plt.scatter(traj[0, 0], traj[0, 1], color=color, s=20, marker='o')\n", " plt.scatter(traj[-1, 0], traj[-1, 1], color=color, s=20, marker='s')\n", "\n", "# Calculate the vector field based on the average direction of trajectories\n", "grid_size = 20\n", "x = np.linspace(trajectories_2d[:, :, 0].min(), trajectories_2d[:, :, 0].max(), grid_size)\n", "y = np.linspace(trajectories_2d[:, :, 1].min(), trajectories_2d[:, :, 1].max(), grid_size)\n", "X, Y = np.meshgrid(x, y)\n", "\n", "U = np.zeros_like(X)\n", "V = np.zeros_like(Y)\n", "\n", "for i in range(grid_size):\n", " for j in range(grid_size):\n", " nearby_trajectories = [traj for traj in trajectories_2d if\n", " (x[i]-0.5 < traj[:, 0]).any() and (traj[:, 0] < x[i]+0.5).any() and\n", " (y[j]-0.5 < traj[:, 1]).any() and (traj[:, 1] < y[j]+0.5).any()]\n", " if nearby_trajectories:\n", " directions = np.diff(nearby_trajectories, axis=1)\n", " avg_direction = np.mean(directions, axis=(0, 1))\n", " U[j, i], V[j, i] = avg_direction\n", "\n", "# Normalize the vector field\n", "magnitude = np.sqrt(U**2 + V**2)\n", "U = U / np.where(magnitude > 0, magnitude, 1)\n", "V = V / np.where(magnitude > 0, magnitude, 1)\n", "\n", "plt.streamplot(X, Y, U, V, density=1, 
color='gray', linewidth=0.5, arrowsize=0.5)\n", "\n", "# Find key points using KMeans clustering\n", "n_clusters = 5 # Adjust this number based on how many key points you want\n", "kmeans = KMeans(n_clusters=n_clusters)\n", "flattened_trajectories = trajectories_2d.reshape(-1, 2)\n", "kmeans.fit(flattened_trajectories)\n", "key_points = kmeans.cluster_centers_\n", "\n", "# Plot key points\n", "plt.scatter(key_points[:, 0], key_points[:, 1], color='blue', s=100, marker='*', zorder=5)\n", "\n", "# Add labels to key points\n", "for i, point in enumerate(key_points):\n", " plt.annotate(f'Key Point {i+1}', (point[0], point[1]), xytext=(5, 5),\n", " textcoords='offset points', fontsize=8, color='blue')\n", "\n", "# Add labels and title\n", "plt.xlabel('PCA 1')\n", "plt.ylabel('PCA 2')\n", "plt.title('2D Reasoning Trajectories with Phase Space Features and Key Points')\n", "\n", "# Add a legend\n", "valid_line = plt.Line2D([], [], color='green', label='Valid Chains')\n", "invalid_line = plt.Line2D([], [], color='red', label='Invalid Chains')\n", "vector_field_line = plt.Line2D([], [], color='gray', label='Vector Field')\n", "key_point_marker = plt.Line2D([], [], color='blue', marker='*', linestyle='None',\n", " markersize=10, label='Key Points')\n", "plt.legend(handles=[valid_line, invalid_line, vector_field_line, key_point_marker])\n", "\n", "# Show the plot\n", "plt.tight_layout()\n", "plt.savefig('2d_reasoning_trajectories_with_key_points.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ], "metadata": { "id": "m38JkWLcQKCc" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "fig = plt.figure(figsize=(10, 8))\n", "ax = fig.add_subplot(111, projection='3d')\n", "\n", "for i, trajectory in enumerate(trajectories_3d[:100]): # Limit to first 100 for clarity\n", " color = 'green' if df.iloc[i]['is_valid'] else 'red'\n", " ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], color=color, alpha=0.5)\n", " ax.scatter(trajectory[0, 0], 
trajectory[0, 1], trajectory[0, 2], color=color, s=20)\n", " ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2], color=color, s=20, marker='s')\n", "\n", "ax.set_xlabel('PCA 1')\n", "ax.set_ylabel('PCA 2')\n", "ax.set_zlabel('PCA 3')\n", "ax.set_title('Reasoning Trajectories in 3D Embedding Space')\n", "plt.tight_layout()\n", "plt.show()" ], "metadata": { "id": "nVVADjWNNVy_" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "def compute_vector_field(trajectories, grid_size=10):\n", " # Determine the bounds of the space\n", " all_points = np.vstack(trajectories)\n", " mins = np.min(all_points, axis=0)\n", " maxs = np.max(all_points, axis=0)\n", "\n", " # Create a grid\n", " x = np.linspace(mins[0], maxs[0], grid_size)\n", " y = np.linspace(mins[1], maxs[1], grid_size)\n", " z = np.linspace(mins[2], maxs[2], grid_size)\n", " X, Y, Z = np.meshgrid(x, y, z)\n", "\n", " U = np.zeros((grid_size, grid_size, grid_size))\n", " V = np.zeros((grid_size, grid_size, grid_size))\n", " W = np.zeros((grid_size, grid_size, grid_size))\n", "\n", " # Compute average direction for each grid cell\n", " for trajectory in trajectories:\n", " directions = np.diff(trajectory, axis=0)\n", " for direction, point in zip(directions, trajectory[:-1]):\n", " i, j, k = np.floor((point - mins) / (maxs - mins) * (grid_size - 1)).astype(int)\n", " U[i, j, k] += direction[0]\n", " V[i, j, k] += direction[1]\n", " W[i, j, k] += direction[2]\n", "\n", " # Normalize\n", " magnitude = np.sqrt(U**2 + V**2 + W**2)\n", " U /= np.where(magnitude > 0, magnitude, 1)\n", " V /= np.where(magnitude > 0, magnitude, 1)\n", " W /= np.where(magnitude > 0, magnitude, 1)\n", "\n", " return X, Y, Z, U, V, W\n", "\n", "# Set up the figure and 3D axis\n", "fig = plt.figure(figsize=(12, 10))\n", "ax = fig.add_subplot(111, projection='3d')\n", "\n", "# Plot trajectories\n", "for i, trajectory in enumerate(trajectories_3d[:100]): # Limit to first 100 for clarity\n", " color = 
'green' if df.iloc[i]['is_valid'] else 'red'\n", " ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], color=color, alpha=0.5)\n", " ax.scatter(trajectory[0, 0], trajectory[0, 1], trajectory[0, 2], color=color, s=20)\n", " ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2], color=color, s=20, marker='s')\n", "\n", "# Compute and plot vector field\n", "X, Y, Z, U, V, W = compute_vector_field(trajectories_3d[:100])\n", "ax.quiver(X, Y, Z, U, V, W, length=0.5, normalize=True, color='blue', alpha=0.3)\n", "\n", "ax.set_xlabel('PCA 1')\n", "ax.set_ylabel('PCA 2')\n", "ax.set_zlabel('PCA 3')\n", "ax.set_title('Reasoning Trajectories and Phase Space in 3D Embedding Space')\n", "\n", "plt.tight_layout()\n", "plt.savefig('3d_phase_space_plot.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ], "metadata": { "id": "l0UmPM8xftuv" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "plt.figure(figsize=(10, 6))\n", "\n", "# Create the histogram plot\n", "sns.histplot(data=df, x='energy_conservation', kde=True, bins=50, color='green')\n", "\n", "# Set the title and labels\n", "plt.title(\"Distribution of Energy Conservation Scores\", fontsize=16)\n", "plt.xlabel(\"Energy Conservation Score\", fontsize=12)\n", "plt.ylabel(\"Frequency\", fontsize=12)\n", "\n", "# Adjust layout and display\n", "plt.tight_layout()\n", "plt.savefig('energy_conservation_distribution.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ], "metadata": { "id": "qca1p7PhOaU6" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n", "\n", "sns.histplot(data=df, x='trajectory_magnitude', hue='is_valid', element='step', stat='density', common_norm=False, ax=ax1)\n", "ax1.set_title('Distribution of Trajectory Magnitudes')\n", "ax1.set_xlabel('Magnitude')\n", "ax1.set_ylabel('Density')\n", "\n", "sns.histplot(data=df, x='trajectory_angle', hue='is_valid', element='step', 
stat='density', common_norm=False, ax=ax2)\n", "ax2.set_title('Distribution of Trajectory Angles')\n", "ax2.set_xlabel('Angle (radians)')\n", "ax2.set_ylabel('Density')\n", "\n", "plt.tight_layout()\n", "plt.savefig('magnitude_angle_distribution.png', dpi=300, bbox_inches='tight')\n", "plt.close()" ], "metadata": { "id": "I8VrMb6MMsOc" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Additional analysis\n", "print(f\"Average Energy Conservation Score: {df['energy_conservation'].mean():.4f}\")\n", "print(f\"Correlation between Energy Conservation and Validity: {df['energy_conservation'].corr(df['is_valid']):.4f}\")\n", "print(f\"Average Hamiltonian Energy for Valid Chains: {valid_chains['H_energy'].mean():.4f}\")\n", "print(f\"Average Hamiltonian Energy for Invalid Chains: {invalid_chains['H_energy'].mean():.4f}\")\n", "\n", "# T-test for difference in Hamiltonian Energy\n", "t_stat, p_value = stats.ttest_ind(valid_chains['H_energy'], invalid_chains['H_energy'])\n", "print(f\"\\nT-test for difference in Hamiltonian Energy:\")\n", "print(f\"t-statistic: {t_stat:.4f}\")\n", "print(f\"p-value: {p_value:.4f}\")" ], "metadata": { "id": "FHmMSmNAI-qc" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Geometric analysis" ], "metadata": { "id": "1s_DosZEWVhy" } }, { "cell_type": "code", "source": [ "fig = plt.figure(figsize=(10, 8))\n", "ax = fig.add_subplot(111, projection='3d')\n", "\n", "for i, trajectory in enumerate(trajectories_3d[:100]): # Limit to first 100 for clarity\n", " color = 'green' if df.iloc[i]['is_valid'] else 'red'\n", " ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], color=color, alpha=0.5)\n", " ax.scatter(trajectory[0, 0], trajectory[0, 1], trajectory[0, 2], color=color, s=20)\n", " ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2], color=color, s=20, marker='s')\n", "\n", "ax.set_xlabel('PCA 1')\n", "ax.set_ylabel('PCA 2')\n", "ax.set_zlabel('PCA 
3')\n", "ax.set_title('Reasoning Trajectories in 3D Embedding Space')\n", "plt.tight_layout()\n", "plt.savefig('3d_trajectories.png', dpi=300, bbox_inches='tight')\n", "plt.close()\n", "\n", "# 2. Trajectory Energy by Chain Index\n", "plt.figure(figsize=(10, 6))\n", "sns.scatterplot(x=df.index, y=trajectory_energies, hue=df['is_valid'], palette={True: 'green', False: 'red'})\n", "plt.title('Trajectory Energy by Chain Index')\n", "plt.xlabel('Chain Index')\n", "plt.ylabel('Energy')\n", "plt.legend(title='Is Valid')\n", "plt.tight_layout()\n", "plt.savefig('trajectory_energy.png', dpi=300, bbox_inches='tight')\n", "plt.close()" ], "metadata": { "id": "2Sz-nqGA9p8B" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Energy Plot\n", "plt.figure(figsize=(12, 6))\n", "sns.scatterplot(x=df.index, y=trajectory_energies, hue=df['is_valid'], palette={True: 'green', False: 'red'})\n", "plt.title('Trajectory Energy by Chain Index')\n", "plt.xlabel('Chain Index')\n", "plt.ylabel('Energy')\n", "plt.legend(title='Is Valid')\n", "plt.tight_layout()\n", "plt.show()" ], "metadata": { "id": "5rN0K7tM_68P" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "plt.figure(figsize=(12, 6))\n", "\n", "# Define colors explicitly\n", "colors = {'Valid': 'green', 'Invalid': 'red'}\n", "\n", "# Create the histogram plot with explicit colors\n", "sns.histplot(data=pd.DataFrame({'Energy': trajectory_energies, 'Is Valid': df['is_valid'].map({True: 'Valid', False: 'Invalid'})}),\n", " x='Energy', hue='Is Valid', element='step', stat='density', common_norm=False,\n", " palette=colors)\n", "\n", "plt.title('Distribution of Trajectory Energies', fontsize=16)\n", "plt.xlabel('Energy', fontsize=14)\n", "plt.ylabel('Density', fontsize=14)\n", "\n", "# Create a custom legend\n", "handles = [plt.Rectangle((0,0),1,1, color=color) for color in colors.values()]\n", "labels = list(colors.keys())\n", "plt.legend(handles, labels, title='Trajectory 
Validity', title_fontsize='13', fontsize='12')\n", "\n", "plt.tight_layout()\n", "plt.savefig('energy_distribution_plot.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ], "metadata": { "id": "iRG8GKRF__3a" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Distribution of Trajectory Magnitudes and Angles\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n", "\n", "sns.histplot(data=df, x='trajectory_magnitude', hue='is_valid', element='step', stat='density', common_norm=False, ax=ax1)\n", "ax1.set_title('Distribution of Trajectory Magnitudes')\n", "ax1.set_xlabel('Magnitude')\n", "ax1.set_ylabel('Density')\n", "\n", "sns.histplot(data=df, x='trajectory_angle', hue='is_valid', element='step', stat='density', common_norm=False, ax=ax2)\n", "ax2.set_title('Distribution of Trajectory Angles')\n", "ax2.set_xlabel('Angle (radians)')\n", "ax2.set_ylabel('Density')\n", "\n", "plt.tight_layout()\n", "plt.savefig('magnitude_angle_distribution.png', dpi=300, bbox_inches='tight')\n", "plt.close()" ], "metadata": { "id": "yLJie7VYoas6" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Trajectory Magnitude vs Angle\n", "plt.figure(figsize=(10, 8))\n", "sns.scatterplot(data=df, x='trajectory_angle', y='trajectory_magnitude', hue='is_valid', alpha=0.6)\n", "plt.title('Trajectory Magnitude vs Angle')\n", "plt.xlabel('Angle (radians)')\n", "plt.ylabel('Magnitude')\n", "plt.legend(title='Is Valid')\n", "plt.tight_layout()\n", "plt.savefig('magnitude_vs_angle.png', dpi=300, bbox_inches='tight')\n", "plt.close()\n", "\n", "# 6. 
Trajectory Properties Comparison\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))\n", "\n", "sns.boxplot(x='is_valid', y='length', data=traj_properties, ax=ax1)\n", "ax1.set_title('Trajectory Length')\n", "ax1.set_xlabel('Is Valid')\n", "ax1.set_ylabel('Length')\n", "\n", "sns.boxplot(x='is_valid', y='smoothness', data=traj_properties, ax=ax2)\n", "ax2.set_title('Trajectory Smoothness')\n", "ax2.set_xlabel('Is Valid')\n", "ax2.set_ylabel('Smoothness')\n", "\n", "plt.tight_layout()\n", "plt.savefig('trajectory_properties.png', dpi=300, bbox_inches='tight')\n", "plt.close()" ], "metadata": { "id": "OOasgefio41H" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "plt.figure(figsize=(12, 8))\n", "\n", "# Define colors explicitly\n", "colors = {'Valid': 'blue', 'Invalid': 'red'}\n", "\n", "# Prepare the data\n", "plot_data = df.copy()\n", "plot_data['Validity'] = df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n", "\n", "# Create the scatter plot with explicit colors\n", "sns.scatterplot(data=plot_data, x='trajectory_angle', y='trajectory_magnitude', hue='Validity',\n", " palette=colors, alpha=0.6)\n", "\n", "plt.title('Trajectory Magnitude vs Angle', fontsize=16)\n", "plt.xlabel('Angle (radians)', fontsize=14)\n", "plt.ylabel('Magnitude', fontsize=14)\n", "\n", "# Create custom legend handles\n", "handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, alpha=0.6)\n", " for color in colors.values()]\n", "labels = list(colors.keys())\n", "\n", "# Add the legend with custom handles\n", "plt.legend(handles, labels, title='Chain Validity', title_fontsize='13', fontsize='12')\n", "\n", "plt.tight_layout()\n", "plt.savefig('refined_magnitude_vs_angle_plot.png', dpi=300, bbox_inches='tight')\n", "plt.show()\n", "\n", "# Calculate and print statistical information\n", "valid_data = df[df['is_valid']]\n", "invalid_data = df[~df['is_valid']]\n", "\n", "print(\"Statistical Information:\")\n", 
"print(f\"Correlation between Angle and Magnitude (overall): {df['trajectory_angle'].corr(df['trajectory_magnitude']):.3f}\")\n", "print(f\"Correlation for Valid Chains: {valid_data['trajectory_angle'].corr(valid_data['trajectory_magnitude']):.3f}\")\n", "print(f\"Correlation for Invalid Chains: {invalid_data['trajectory_angle'].corr(invalid_data['trajectory_magnitude']):.3f}\")\n", "\n", "# Perform t-tests\n", "t_stat_angle, p_value_angle = stats.ttest_ind(valid_data['trajectory_angle'], invalid_data['trajectory_angle'])\n", "t_stat_mag, p_value_mag = stats.ttest_ind(valid_data['trajectory_magnitude'], invalid_data['trajectory_magnitude'])\n", "\n", "print(\"\\nT-test for difference in Trajectory Angle:\")\n", "print(f\"t-statistic: {t_stat_angle:.4f}\")\n", "print(f\"p-value: {p_value_angle:.4f}\")\n", "\n", "print(\"\\nT-test for difference in Trajectory Magnitude:\")\n", "print(f\"t-statistic: {t_stat_mag:.4f}\")\n", "print(f\"p-value: {p_value_mag:.4f}\")\n", "\n", "# Calculate and print mean values\n", "print(\"\\nMean Values:\")\n", "print(f\"Mean Angle for Valid Chains: {valid_data['trajectory_angle'].mean():.3f}\")\n", "print(f\"Mean Angle for Invalid Chains: {invalid_data['trajectory_angle'].mean():.3f}\")\n", "print(f\"Mean Magnitude for Valid Chains: {valid_data['trajectory_magnitude'].mean():.3f}\")\n", "print(f\"Mean Magnitude for Invalid Chains: {invalid_data['trajectory_magnitude'].mean():.3f}\")" ], "metadata": { "id": "6pBMYGiKBR7f" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Statistical tests\n", "valid_mag = df[df['is_valid']]['trajectory_magnitude']\n", "invalid_mag = df[~df['is_valid']]['trajectory_magnitude']\n", "mag_ttest = ttest_ind(valid_mag, invalid_mag)\n", "\n", "valid_ang = df[df['is_valid']]['trajectory_angle']\n", "invalid_ang = df[~df['is_valid']]['trajectory_angle']\n", "ang_ttest = ttest_ind(valid_ang, invalid_ang)\n", "\n", "print(\"T-test for trajectory magnitude:\", mag_ttest)\n", 
"print(\"T-test for trajectory angle:\", ang_ttest)\n",
"\n",
"# Correlation with energy\n",
"mag_energy_corr = df['trajectory_magnitude'].corr(df['H_energy'])\n",
"ang_energy_corr = df['trajectory_angle'].corr(df['H_energy'])\n",
"\n",
"print(\"Correlation between magnitude and H energy:\", mag_energy_corr)\n",
"print(\"Correlation between angle and H energy:\", ang_energy_corr)"
], "metadata": { "id": "i2ccr--MBXYa" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"def calculate_curvature(trajectory):\n",
"    \"\"\"Menger curvature of the (start, middle, end) points of a trajectory.\n",
"\n",
"    `trajectory` is indexed as (n_points, n_dims), the same layout that\n",
"    calculate_rate_of_change differences along axis 0. For a 3-point\n",
"    trajectory this is 4 * area / (a * b * c), the inverse radius of the\n",
"    circle through the three points. Degenerate triangles (repeated points)\n",
"    return 0.0 instead of NaN.\n",
"    \"\"\"\n",
"    trajectory = np.asarray(trajectory, dtype=float)\n",
"    start = trajectory[0]\n",
"    middle = trajectory[len(trajectory) // 2]\n",
"    end = trajectory[-1]\n",
"\n",
"    # Side lengths of the triangle spanned by the three points\n",
"    a = np.linalg.norm(middle - start)\n",
"    b = np.linalg.norm(end - middle)\n",
"    c = np.linalg.norm(end - start)\n",
"    if a == 0 or b == 0 or c == 0:\n",
"        return 0.0\n",
"\n",
"    # Heron's formula; clip rounding noise on (near-)collinear points so\n",
"    # sqrt never sees a tiny negative number.\n",
"    s = (a + b + c) / 2\n",
"    area = np.sqrt(max(s * (s - a) * (s - b) * (s - c), 0.0))\n",
"\n",
"    return 4 * area / (a * b * c)\n",
"\n",
"def calculate_rate_of_change(trajectory):\n",
"    \"\"\"Mean Euclidean step length between consecutive points of a trajectory.\"\"\"\n",
"    changes = np.diff(trajectory, axis=0)\n",
"    rates = np.linalg.norm(changes, axis=1)\n",
"    return np.mean(rates)\n",
"\n",
"# Calculate curvature and rate of change\n",
"curvatures = [calculate_curvature(traj) for traj in trajectories_3d]\n",
"rates_of_change = [calculate_rate_of_change(traj) for traj in trajectories_3d]\n",
"\n",
"# Add these to the dataframe\n",
"df['curvature'] = curvatures\n",
"df['rate_of_change'] = rates_of_change\n",
"\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"\n",
"# Define colors explicitly\n",
"colors = {'Valid': 'blue', 'Invalid': 'red'}\n",
"\n",
"# Prepare the data\n",
"plot_data = pd.DataFrame({\n",
"    'Curvature': df['curvature'],\n",
"    'Validity': df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n",
"})\n",
"\n",
"# Create the histogram plot with explicit colors\n",
"sns.histplot(data=plot_data, x='Curvature', hue='Validity',\n",
"             element='step', stat='density', common_norm=False,\n",
"             palette=colors)\n",
"\n",
"plt.title('Distribution of Trajectory Curvatures', fontsize=16)\n",
"plt.xlabel('Curvature', fontsize=14)\n",
"plt.ylabel('Density', fontsize=14)\n",
"\n",
"# Calculate mean curvatures for valid and invalid chains\n",
"mean_valid = df[df['is_valid']]['curvature'].mean()\n",
"mean_invalid = df[~df['is_valid']]['curvature'].mean()\n",
"\n",
"# Add vertical lines for mean curvatures\n",
"plt.axvline(x=mean_valid, color='blue', linestyle='--', label='Mean Valid')\n",
"plt.axvline(x=mean_invalid, color='red', linestyle='--', label='Mean Invalid')\n",
"\n",
"# Build the legend explicitly (as in the rate-of-change plot): a bare\n",
"# plt.legend() issued before any labelled artist exists replaces seaborn's\n",
"# hue legend with an empty one.\n",
"handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in colors]\n",
"handles += [plt.Line2D([0], [0], color='blue', linestyle='--'),\n",
"            plt.Line2D([0], [0], color='red', linestyle='--')]\n",
"labels = list(colors.keys()) + ['Mean Valid', 'Mean Invalid']\n",
"plt.legend(handles, labels, title='Chain Validity', title_fontsize='13', fontsize='12')\n",
"\n",
"# Add text annotations for mean curvatures\n",
"plt.text(mean_valid, plt.gca().get_ylim()[1], f'Mean Valid: {mean_valid:.3f}',\n",
"         rotation=90, va='top', ha='right', color='blue')\n",
"plt.text(mean_invalid, plt.gca().get_ylim()[1], f'Mean Invalid: {mean_invalid:.3f}',\n",
"         rotation=90, va='top', ha='left', color='red')\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('refined_curvature_distribution.png', dpi=300, bbox_inches='tight')\n",
"plt.show()\n",
"\n",
"# Calculate and print statistical information\n",
"valid_curv = df[df['is_valid']]['curvature']\n",
"invalid_curv = df[~df['is_valid']]['curvature']\n",
"t_stat, p_value = stats.ttest_ind(valid_curv, invalid_curv)\n",
"print(f\"T-test for difference in Curvature: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")"
], "metadata": { "id": "BlXQkEKjCrSK" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"plt.figure(figsize=(12, 6))\n",
"\n",
"# Define colors explicitly\n",
"colors = {'Valid': 'blue', 'Invalid': 'red'}\n",
"\n",
"# Prepare the data\n",
"plot_data = pd.DataFrame({\n",
"    'Rate of Change': df['rate_of_change'],\n",
"    'Validity': df['is_valid'].map({True: 'Valid', False: 'Invalid'})\n",
"})\n",
"\n",
"# Create the histogram plot with explicit colors\n",
"sns.histplot(data=plot_data, x='Rate of Change', hue='Validity',\n",
"             element='step', stat='density', common_norm=False,\n",
"             palette=colors)\n",
"\n",
"plt.title('Distribution of Trajectory Rates of Change', fontsize=16)\n",
"plt.xlabel('Rate of Change', fontsize=14)\n",
"plt.ylabel('Density', fontsize=14)\n",
"\n",
"# Create custom legend handles\n",
"handles = [plt.Rectangle((0,0),1,1, color=colors[label]) for label in colors]\n",
"labels = list(colors.keys())\n",
"\n",
"# Add the legend with custom handles\n",
"plt.legend(handles, labels, title='Chain Validity', title_fontsize='13', fontsize='12')\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('simplified_rate_of_change_distribution.png', dpi=300, bbox_inches='tight')\n",
"plt.show()\n",
"\n",
"# Calculate and print statistical information\n",
"valid_roc = df[df['is_valid']]['rate_of_change']\n",
"invalid_roc = df[~df['is_valid']]['rate_of_change']\n",
"t_stat, p_value = stats.ttest_ind(valid_roc, invalid_roc)\n",
"\n",
"mean_valid = valid_roc.mean()\n",
"mean_invalid = invalid_roc.mean()\n",
"\n",
"print(\"Distribution of Trajectory Rates of Change\")\n",
"print(f\"Average Rate of Change for Valid Chains: {mean_valid:.3f}\")\n",
"print(f\"Average Rate of Change for Invalid Chains: {mean_invalid:.3f}\")\n",
"print(f\"Correlation between Rate of Change and Validity: {df['rate_of_change'].corr(df['is_valid']):.3f}\")\n",
"print(\"\\nT-test for difference in Rate of Change:\")\n",
"print(f\"t-statistic: {t_stat:.4f}\")\n",
"print(f\"p-value: {p_value:.4f}\")"
], "metadata": { "id": "T7GzkWJzCwJe" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Statistical tests (ttest_ind via the `stats` module from the imports cell)\n",
"df['curvature'] = df['curvature'].fillna(0)\n",
"df['rate_of_change'] = df['rate_of_change'].astype(float)\n",
"valid_curv = df[df['is_valid']]['curvature']\n",
"invalid_curv = df[~df['is_valid']]['curvature']\n",
"curv_ttest = stats.ttest_ind(valid_curv, invalid_curv)\n",
"\n",
"valid_roc = df[df['is_valid']]['rate_of_change']\n",
"invalid_roc = df[~df['is_valid']]['rate_of_change']\n",
"roc_ttest = stats.ttest_ind(valid_roc, invalid_roc)\n",
"\n",
"print(\"T-test for trajectory curvature:\", curv_ttest)\n",
"print(\"T-test for trajectory rate of change:\", roc_ttest)\n",
"\n",
"# Correlation with energy\n",
"curv_energy_corr = df['curvature'].corr(df['H_energy'])\n",
"roc_energy_corr = df['rate_of_change'].corr(df['H_energy'])\n",
"\n",
"print(\"Correlation between curvature and energy:\", curv_energy_corr)\n",
"print(\"Correlation between rate of change and energy:\", roc_energy_corr)"
], "metadata": { "id": "0PabrOYpC7dK" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Frenet's framework\n",
"def reduce_dimensionality(trajectories, n_components=3):\n",
"    \"\"\"Project every trajectory point onto one shared PCA basis.\n",
"\n",
"    All trajectories are stacked, a single PCA is fitted, and the result is\n",
"    reshaped to (n_trajectories, n_points, n_components); the reshape assumes\n",
"    every trajectory has the same number of points.\n",
"    Returns the reduced array and the fitted PCA object.\n",
"    \"\"\"\n",
"    flattened = np.vstack(trajectories)\n",
"    pca = decomposition.PCA(n_components=n_components)\n",
"    reduced = pca.fit_transform(flattened)\n",
"    return reduced.reshape(len(trajectories), -1, n_components), pca\n",
"\n",
"def frenet_serret_frame(trajectory):\n",
"    \"\"\"Discrete Frenet-Serret frame (T, N, B) of a 3-D trajectory.\n",
"\n",
"    Tangents are normalised point differences, normals are normalised tangent\n",
"    differences and binormals are T x N. Zero-length vectors stay zero.\n",
"    Each returned array has len(trajectory) - 2 rows.\n",
"    \"\"\"\n",
"    # Tangent vectors. `out=` is required: with where=... and no out array,\n",
"    # numpy leaves the masked-out entries uninitialised.\n",
"    T = np.diff(trajectory, axis=0)\n",
"    T_norm = np.linalg.norm(T, axis=1, keepdims=True)\n",
"    T = np.divide(T, T_norm, out=np.zeros_like(T, dtype=float), where=T_norm != 0)\n",
"\n",
"    # Normal vectors\n",
"    N = np.diff(T, axis=0)\n",
"    N_norm = np.linalg.norm(N, axis=1, keepdims=True)\n",
"    N = np.divide(N, N_norm, out=np.zeros_like(N, dtype=float), where=N_norm != 0)\n",
"\n",
"    # Binormal vectors\n",
"    B = np.cross(T[:-1], N)\n",
"\n",
"    return T[:-1], N, B\n",
"\n",
"def compute_curvature_torsion(T, N, B):\n",
"    \"\"\"Mean discrete curvature and torsion from a Frenet-Serret frame.\n",
"\n",
"    Curvature is |dT| per step; torsion follows dB/ds = -tau * N, i.e.\n",
"    tau = -dB . N per step (not divided by arc length). The means are NaN\n",
"    (with a RuntimeWarning) when the frame is too short to difference.\n",
"    \"\"\"\n",
"    dT = np.diff(T, axis=0)\n",
"    curvature = np.linalg.norm(dT, axis=1)\n",
"\n",
"    # Torsion: Frenet-Serret gives dB/ds = -tau N, hence the minus sign.\n",
"    dB = np.diff(B, axis=0)\n",
"    torsion = -np.sum(dB * N[1:], axis=1)\n",
"\n",
"    return np.mean(curvature), np.mean(torsion)\n",
"\n",
"# Reduce dimensionality of trajectories\n",
"reduced_trajectories, pca = reduce_dimensionality(trajectories)\n",
"\n",
"# Compute Frenet-Serret frames and curvature/torsion\n",
"curvatures = []\n",
"torsions = []\n",
"for i, traj in enumerate(reduced_trajectories):\n",
"    try:\n",
"        T, N, B = frenet_serret_frame(traj)\n",
"        curvature, torsion = compute_curvature_torsion(T, N, B)\n",
"    except Exception as e:\n",
"        print(f\"Error processing trajectory {i}: {str(e)}\")\n",
"        print(f\"Trajectory shape: {traj.shape}\")\n",
"        curvature, torsion = np.nan, np.nan\n",
"    curvatures.append(curvature)\n",
"    torsions.append(torsion)\n",
"\n",
"df['curvature'] = curvatures\n",
"df['torsion'] = torsions\n",
"\n",
"# Keep rows whose frame failed (NaN) instead of dropping them: `df` must stay\n",
"# row-aligned with trajectories_3d, reduced_trajectories and trajectories_pca,\n",
"# which later cells pair with it by position.\n",
"n_failed = int(df[['curvature', 'torsion']].isna().any(axis=1).sum())\n",
"if n_failed:\n",
"    print(f\"{n_failed} trajectories have NaN curvature/torsion and are excluded from curvature statistics\")\n"
], "metadata": { "id": "hgpHHxRz438n" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Analyze the principal components\n",
"explained_variance_ratio = pca.explained_variance_ratio_\n",
"cumulative_variance_ratio = np.cumsum(explained_variance_ratio)\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(range(1, len(explained_variance_ratio) + 1), cumulative_variance_ratio, 'bo-')\n",
"plt.xlabel('Number of Components', fontsize=14)\n",
"plt.ylabel('Cumulative Explained Variance Ratio', fontsize=14)\n",
"plt.title('Explained Variance Ratio by Principal Components', fontsize=16)\n",
"plt.savefig('pca_explained_variance.png', dpi=300, bbox_inches='tight')\n",
"plt.show()\n",
"\n",
"print(f\"Explained variance ratio of first 3 components: {explained_variance_ratio[:3]}\")\n",
"print(f\"Cumulative explained variance ratio of first 3 components: {cumulative_variance_ratio[2]:.4f}\")"
], "metadata": { "id": "UHASmPhm5dsa" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Compute and visualize Hamiltonian along trajectories\n",
"\n",
"def hamiltonian(q, p, q_goal):\n",
"    \"\"\"Energy of state q with momentum p: 0.5 * |p|^2 + sophisticated_potential(q, q_goal).\"\"\"\n",
"    T = 0.5 * np.dot(p, p)  # Kinetic energy\n",
"    V = sophisticated_potential(q, q_goal)  # Potential energy\n",
"    return T + V\n",
"\n",
"def sophisticated_potential(q, q_goal):\n",
"    \"\"\"Potential energy: -cosine_similarity(q, q_goal) + 0.1 * |q|.\n",
"\n",
"    The similarity term is taken as 0 when either vector has zero norm, which\n",
"    avoids a division by zero (and a NaN Hamiltonian) for a zero state.\n",
"    \"\"\"\n",
"    norm_product = np.linalg.norm(q) * np.linalg.norm(q_goal)\n",
"    similarity = np.dot(q, q_goal) / norm_product if norm_product > 0 else 0.0\n",
"    complexity = np.linalg.norm(q)  # Assume more complex states have higher norm\n",
"    return -similarity + 0.1 * complexity  # Balance between relevance and complexity\n",
"\n",
"# Compute and visualize Hamiltonian along trajectories\n",
"hamiltonians = []\n",
"q_goal = np.mean([traj[-1] for traj in trajectories], axis=0)  # Assuming the goal is the average final state\n",
"\n",
"for traj in trajectories:\n",
"    H = []\n",
"    for i in range(len(traj)):\n",
"        q = traj[i]\n",
"        # Estimate momentum as the difference between consecutive states\n",
"        p = traj[i] - traj[i-1] if i > 0 else np.zeros_like(q)\n",
"        H.append(hamiltonian(q, p, q_goal))\n",
"    hamiltonians.append(H)\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"for i, H in enumerate(hamiltonians[:20]):  # Plot first 20 for clarity\n",
"    plt.plot(H, label=f'Trajectory {i+1}')\n",
"plt.title('Hamiltonian Evolution Along Reasoning Trajectories', fontsize=16)\n",
"plt.xlabel('Time Step', fontsize=16)\n",
"plt.ylabel('Hamiltonian',fontsize=16)\n",
"plt.legend()\n",
"plt.savefig('hamiltonian_evolution_plot.png', dpi=300, bbox_inches='tight')\n",
"plt.show()\n",
"\n",
"# Statistical analysis (rows whose Frenet frame failed hold NaN and are dropped here)\n",
"valid_curvature = df[df['is_valid']]['curvature'].dropna()\n",
"invalid_curvature = df[~df['is_valid']]['curvature'].dropna()\n",
"t_stat, p_value = stats.ttest_ind(valid_curvature, invalid_curvature)\n",
"\n",
"print(f\"T-test for curvature: t-statistic = {t_stat}, p-value = {p_value}\")\n",
"\n",
"# Correlation analysis (Series.corr excludes missing values)\n",
"correlation = df['curvature'].corr(df['torsion'])\n",
"print(f\"Correlation between curvature and torsion: {correlation}\")\n",
"\n"
], "metadata": { "id": "v0V1WiVN6F6g" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# 3D plot of trajectories\n",
"fig = plt.figure(figsize=(12,12))\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"\n",
"for i, traj in enumerate(trajectories_3d[:20]):  # Plot first 20 for clarity\n",
"    color = 'green' if df.iloc[i]['is_valid'] else 'red'\n",
"    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=color, alpha=0.6)\n",
"\n",
"ax.set_xlabel('PCA 1', fontsize=14)\n",
"ax.set_ylabel('PCA 2', fontsize=14)\n",
"ax.set_zlabel('PCA 3', fontsize=14)\n",
"ax.set_title('Reasoning Trajectories in PCA Space', fontsize=16)\n",
"# Add legend. The proxies are defined here so the cell does not depend on\n",
"# earlier kernel state; colours match the lines drawn above.\n",
"valid_handle = plt.Line2D([0], [0], color='green')\n",
"invalid_handle = plt.Line2D([0], [0], color='red')\n",
"ax.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right')\n",
"plt.savefig('pca_trajectories_plot.png', dpi=300, bbox_inches='tight')\n",
"plt.show()"
], "metadata": { "id": "7BuXJCesA-2u" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Statistical Analysis\n",
"\n",
"pca_means = np.array([traj.mean(axis=0) for traj in trajectories_3d])\n",
"X = pd.DataFrame(pca_means, columns=['PCA1', 'PCA2', 'PCA3'])\n",
"y = pd.Series(df['is_valid'].values, name='is_valid')\n",
"\n",
"# Ensure 'is_valid' is boolean\n",
"y = y.astype(bool)\n",
"\n",
"# Combine X and y into a single DataFrame\n",
"data = pd.concat([X, y], axis=1)\n",
"\n",
"# 1. MANOVA test\n",
"manova = MANOVA.from_formula('PCA1 + PCA2 + PCA3 ~ is_valid', data=data)\n",
"print(\"MANOVA test results:\")\n",
"print(manova.mv_test())\n",
"\n",
"# 2. T-tests for each PCA dimension\n",
"for i in range(3):\n",
"    t_stat, p_value = stats.ttest_ind(X[f'PCA{i+1}'][y], X[f'PCA{i+1}'][~y])\n",
"    print(f\"T-test for PCA{i+1}: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")\n",
"\n",
"# 3. Logistic Regression (accuracy is in-sample: fitted and scored on the same data)\n",
"log_reg = linear_model.LogisticRegression()\n",
"log_reg.fit(X, y)\n",
"y_pred = log_reg.predict(X)\n",
"accuracy = metrics.accuracy_score(y, y_pred)\n",
"print(f\"Logistic Regression Accuracy: {accuracy:.4f}\")\n",
"\n",
"# 4. Effect sizes (Cohen's d) for each PCA dimension\n",
"for i in range(3):\n",
"    cohens_d = (X[f'PCA{i+1}'][y].mean() - X[f'PCA{i+1}'][~y].mean()) / np.sqrt((X[f'PCA{i+1}'][y].var() + X[f'PCA{i+1}'][~y].var()) / 2)\n",
"    print(f\"Cohen's d for PCA{i+1}: {cohens_d:.4f}\")\n",
"\n",
"# 5. Trajectory length comparison\n",
"trajectory_lengths = np.array([np.sum(np.sqrt(np.sum(np.diff(traj, axis=0)**2, axis=1))) for traj in trajectories_pca])\n",
"t_stat, p_value = stats.ttest_ind(trajectory_lengths[y], trajectory_lengths[~y])\n",
"print(f\"T-test for trajectory lengths: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")"
], "metadata": { "id": "rqPocLPzDFiM" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Correlation between trajectory complexity and validity\n",
"# Analyze trajectory complexity\n",
"def trajectory_complexity(traj):\n",
"    \"\"\"Total path length: sum of Euclidean step lengths along the trajectory.\"\"\"\n",
"    return np.sum(np.linalg.norm(np.diff(traj, axis=0), axis=1))\n",
"\n",
"complexities = [trajectory_complexity(traj) for traj in reduced_trajectories]\n",
"df['complexity'] = complexities\n",
"complexity_correlation = stats.pointbiserialr(df['is_valid'], df['complexity'])\n",
"print(f\"Correlation between trajectory complexity and validity: r = {complexity_correlation.correlation:.4f}, p = {complexity_correlation.pvalue:.4f}\")"
], "metadata": { "id": "csICTST5BcS5" }, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Canonical transformations" ], "metadata": { "id": "c0kKU3xdVpMf" } },
{ "cell_type": "code", "source": [
"def hamiltonian(state, t, k):\n",
"    \"\"\"Simple harmonic oscillator Hamiltonian H = p^2/2 + k q^2/2 (unit mass).\"\"\"\n",
"    q, p = state\n",
"    return p**2 / 2 + k * q**2 / 2\n",
"\n",
"def hamilton_equations(state, t, k):\n",
"    \"\"\"Hamilton's equations for the simple harmonic oscillator: dq/dt = p, dp/dt = -k q.\"\"\"\n",
"    q, p = state\n",
"    dqdt = p\n",
"    dpdt = -k * q\n",
"    return [dqdt, dpdt]\n",
"\n",
"def canonical_transform_to_action_angle(q, p, k):\n",
"    \"\"\"Transform from (q,p) to action-angle variables (I, theta).\n",
"\n",
"    For H = p^2/2 + k q^2/2 the angular frequency is omega = sqrt(k) and the\n",
"    action is I = H / omega; theta satisfies q = sqrt(2I/omega) sin(theta) and\n",
"    p = sqrt(2I*omega) cos(theta), so inverse_canonical_transform undoes it.\n",
"    For k = 1 this reduces to I = (p^2 + q^2) / 2, theta = arctan2(q, p).\n",
"    \"\"\"\n",
"    omega = np.sqrt(k)\n",
"    I = (p**2 + k * q**2) / (2 * omega)\n",
"    theta = np.arctan2(omega * q, p)\n",
"    return I, theta\n",
"\n",
"def inverse_canonical_transform(I, theta, k):\n",
"    \"\"\"Transform from action-angle variables (I, theta) back to (q,p), with omega = sqrt(k).\"\"\"\n",
"    omega = np.sqrt(k)\n",
"    q = np.sqrt(2 * I / omega) * np.sin(theta)\n",
"    p = np.sqrt(2 * I * omega) * np.cos(theta)\n",
"    return q, p\n",
"\n",
"# Parameters\n",
"k = 1.0  # Spring constant\n",
"t = np.linspace(0, 10, 100)\n",
"\n",
"# Apply canonical transformation to our trajectories\n",
"action_angle_trajectories = []\n",
"for traj in trajectories_pca:\n",
"    q, p = traj[:, 0], traj[:, 1]  # Assuming first two PCs represent position and momentum\n",
"    I, theta = canonical_transform_to_action_angle(q, p, k)\n",
"    action_angle_trajectories.append(np.column_stack((I, theta)))\n",
"\n",
"\n",
"# Analysis: per-trajectory mean action and angle range, split by validity\n",
"is_valid_list = df['is_valid'].tolist()\n",
"action_means_valid = [np.mean(traj[:, 0]) for traj, valid in zip(action_angle_trajectories, is_valid_list) if valid]\n",
"action_means_nonvalid = [np.mean(traj[:, 0]) for traj, valid in zip(action_angle_trajectories, is_valid_list) if not valid]\n",
"angle_ranges_valid = [np.ptp(traj[:, 1]) for traj, valid in zip(action_angle_trajectories, is_valid_list) if valid]\n",
"angle_ranges_nonvalid = [np.ptp(traj[:, 1]) for traj, valid in zip(action_angle_trajectories, is_valid_list) if not valid]\n",
"\n",
"print(f\"Mean action for valid chains: {np.mean(action_means_valid):.4f}\")\n",
"print(f\"Mean action for non-valid chains: {np.mean(action_means_nonvalid):.4f}\")\n",
"print(f\"Mean angle range for valid chains: {np.mean(angle_ranges_valid):.4f}\")\n",
"print(f\"Mean angle range for non-valid chains: {np.mean(angle_ranges_nonvalid):.4f}\")\n",
"\n",
"# Statistical tests (`stats` comes from the imports cell)\n",
"t_stat, p_value = stats.ttest_ind(action_means_valid, action_means_nonvalid)\n",
"print(f\"T-test for action means: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")\n",
"\n",
"t_stat, p_value = stats.ttest_ind(angle_ranges_valid, angle_ranges_nonvalid)\n",
"print(f\"T-test for angle ranges: t-statistic = {t_stat:.4f}, p-value = {p_value:.4f}\")\n",
"\n",
"# Classify trajectories based on action and angle properties\n",
"def classify_trajectory(action, angle_range, valid):\n",
"    \"\"\"Label a trajectory from its mean action and angle range.\n",
"\n",
"    Thresholds are mean +/- one standard deviation of the trajectory's own\n",
"    group (valid or non-valid chains), taken from action_means_* and\n",
"    angle_ranges_* computed above. Returns one of five descriptive labels.\n",
"    NOTE: a later cell redefines `classify_trajectory(trajectory)` with a\n",
"    different signature, which shadows this function once that cell runs.\n",
"    \"\"\"\n",
"    # Select the reference group once instead of repeating the conditional\n",
"    actions = action_means_valid if valid else action_means_nonvalid\n",
"    angle_ranges = angle_ranges_valid if valid else angle_ranges_nonvalid\n",
"    high_action = np.mean(actions) + np.std(actions)\n",
"    low_action = np.mean(actions) - np.std(actions)\n",
"    high_angle_range = np.mean(angle_ranges) + np.std(angle_ranges)\n",
"\n",
"    if action > high_action and angle_range > high_angle_range:\n",
"        return \"High energy, complex reasoning\"\n",
"    elif action < low_action and angle_range > high_angle_range:\n",
"        return \"Low energy, exploratory reasoning\"\n",
"    elif action > high_action and angle_range <= high_angle_range:\n",
"        return \"High energy, focused reasoning\"\n",
"    elif action < low_action and angle_range <= high_angle_range:\n",
"        return \"Low energy, simple reasoning\"\n",
"    else:\n",
"        return \"Moderate reasoning\""
], "metadata": { "id": "Pm52IjYTXMMH" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Plotting\n",
"fig = plt.figure(figsize=(15, 5))\n",
"\n",
"# Legend proxies defined here so the cell does not rely on earlier kernel\n",
"# state; colours match the green/red lines drawn below.\n",
"valid_handle = plt.Line2D([0], [0], color='green')\n",
"invalid_handle = plt.Line2D([0], [0], color='red')\n",
"\n",
"# Original space\n",
"ax1 = fig.add_subplot(131)\n",
"for traj, valid in zip(trajectories_pca[:10], df['is_valid'].tolist()[:10]):  # Plot first 10 for clarity\n",
"    color = 'green' if valid else 'red'\n",
"    ax1.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7)\n",
"ax1.set_xlabel('PC1 (q)', fontsize=12)\n",
"ax1.set_ylabel('PC2 (p)', fontsize=12)\n",
"ax1.set_title('Original Phase Space', fontsize=14)\n",
"ax1.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right', fontsize=12)\n",
"\n",
"# Action-Angle space\n",
"ax2 = fig.add_subplot(132)\n",
"for traj, valid in zip(action_angle_trajectories[:10], df['is_valid'].tolist()[:10]):\n",
"    color = 'green' if valid else 'red'\n",
"    ax2.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7)\n",
"ax2.set_xlabel('Action (I)', fontsize=12)\n",
"ax2.set_ylabel('Angle (theta)', fontsize=12)\n",
"ax2.set_title('Action-Angle Space', fontsize=14)\n",
"ax2.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right', fontsize=12)\n",
"\n",
"# 3D visualization\n",
"ax3 = fig.add_subplot(133, projection='3d')\n",
"for traj, valid in zip(action_angle_trajectories[:10], df['is_valid'].tolist()[:10]):\n",
"    color = 'green' if valid else 'red'\n",
"    ax3.plot(traj[:, 0], np.cos(traj[:, 1]), np.sin(traj[:, 1]), color=color, alpha=0.7)\n",
"ax3.set_xlabel('Action (I)', fontsize=12)\n",
"ax3.set_ylabel('cos(theta)', fontsize=12)\n",
"ax3.set_zlabel('sin(theta)', fontsize=12)\n",
"ax3.set_title('3D Action-Angle Space', fontsize=14)\n",
"ax3.legend([valid_handle, invalid_handle], ['Valid', 'Invalid'], loc='upper right', fontsize=12)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('canonical_transformation_analysis_with_validity.png', dpi=300, bbox_inches='tight')\n",
"plt.show()"
], "metadata": { "id": "YlzvprO0ZBo1" }, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Conservation laws" ], "metadata": { "id": "b-FE7nQWW1Oe" } },
{ "cell_type": "code", "source": [
"def calculate_hamiltonian(q, p):\n",
"    \"\"\"Simple Hamiltonian function: 0.5 * (q^2 + p^2).\"\"\"\n",
"    return 0.5 * (q**2 + p**2)\n",
"\n",
"def calculate_angular_momentum(q, p):\n",
"    \"\"\"Angular momentum-like quantity: q * p.\"\"\"\n",
"    return q * p\n",
"\n",
"def calculate_energy_like_quantity(q, p):\n",
"    \"\"\"Energy-like quantity: q^2 - p^2.\"\"\"\n",
"    return q**2 - p**2\n",
"\n",
"def analyze_conservation(trajectories, quantity_func, quantity_name):\n",
"    \"\"\"Absolute change of a quantity between the first and last point of each trajectory.\n",
"\n",
"    Column 0 of each trajectory is treated as q and column 1 as p.\n",
"    `quantity_func(q, p)` is evaluated at the first and last rows and\n",
"    |end - start| is collected per trajectory. `quantity_name` is only a\n",
"    label and does not affect the result.\n",
"    \"\"\"\n",
"    conserved_scores = []\n",
"    for traj in trajectories:\n",
"        # Index the first/last rows explicitly: unpacking traj[:, 0] into two\n",
"        # names only works for trajectories with exactly two points.\n",
"        q_start, p_start = traj[0, 0], traj[0, 1]\n",
"        q_end, p_end = traj[-1, 0], traj[-1, 1]\n",
"        quantity_start = quantity_func(q_start, p_start)\n",
"        quantity_end = quantity_func(q_end, p_end)\n",
"        conserved_scores.append(abs(quantity_end - quantity_start))\n",
"    return conserved_scores\n",
"\n",
"# Analyze conservation for different quantities\n",
"hamiltonian_scores = analyze_conservation(trajectories_2d, calculate_hamiltonian, \"Hamiltonian\")\n",
"angular_momentum_scores = analyze_conservation(trajectories_2d, calculate_angular_momentum, \"Angular Momentum\")\n",
"energy_scores = analyze_conservation(trajectories_2d, calculate_energy_like_quantity, \"Energy-like Quantity\")\n",
"\n",
"# Print some statistics\n",
"print(\"Hamiltonian changes - Mean: {:.4f}, Std: {:.4f}\".format(np.mean(hamiltonian_scores), np.std(hamiltonian_scores)))\n",
"print(\"Angular Momentum changes - Mean: {:.4f}, Std: {:.4f}\".format(np.mean(angular_momentum_scores), np.std(angular_momentum_scores)))\n",
"print(\"Energy-like Quantity changes - Mean: {:.4f}, Std: {:.4f}\".format(np.mean(energy_scores), np.std(energy_scores)))"
], "metadata": { "id": "t_aym0wlWBpg" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Visualize conservation of quantities.\n",
"# x-axis: |quantity(end) - quantity(start)| per trajectory, as returned by analyze_conservation.\n",
"plt.figure(figsize=(15, 5))\n",
"\n",
"plt.subplot(131)\n",
"plt.hist(hamiltonian_scores, bins=20, color='blue', alpha=0.7)\n",
"plt.title(\"Conservation of Hamiltonian\", fontsize=16)\n",
"plt.xlabel(\"Absolute Change\", fontsize=14)\n",
"plt.ylabel(\"Frequency\", fontsize=14)\n",
"\n",
"plt.subplot(132)\n",
"plt.hist(angular_momentum_scores, bins=20, color='green', alpha=0.7)\n",
"plt.title(\"Conservation of Angular Momentum\", fontsize=16)\n",
"plt.xlabel(\"Absolute Change\", fontsize=14)\n",
"plt.ylabel(\"Frequency\", fontsize=14)\n",
"\n",
"plt.subplot(133)\n",
"plt.hist(energy_scores, bins=20, color='red', alpha=0.7)\n",
"plt.title(\"Conservation of Energy-like Quantity\", fontsize=16)\n",
"plt.xlabel(\"Absolute Change\", fontsize=14)\n",
"plt.ylabel(\"Frequency\", fontsize=14)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('conservation_laws_analysis.png', dpi=300, bbox_inches='tight')\n",
"plt.show()"
], "metadata": { "id": "zOFQfeap55P7" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Same three histograms on shared axes so the panels are directly comparable.\n",
"# x-axis: |quantity(end) - quantity(start)| per trajectory.\n",
"all_scores = np.concatenate([hamiltonian_scores, angular_momentum_scores, energy_scores])\n",
"min_score = np.min(all_scores)\n",
"max_score = np.max(all_scores)\n",
"\n",
"# Create bins that cover the entire range\n",
"bins = np.linspace(min_score, max_score, 21)  # 20 equal-width bins\n",
"\n",
"# (scores, title, colour) for each panel, left to right\n",
"panels = [\n",
"    (hamiltonian_scores, \"Conservation of Hamiltonian\", 'blue'),\n",
"    (angular_momentum_scores, \"Conservation of Angular Momentum\", 'green'),\n",
"    (energy_scores, \"Conservation of Energy-like Quantity\", 'red'),\n",
"]\n",
"\n",
"# Highest bin count across all panels, used as the shared y-limit\n",
"max_freq = max(np.histogram(scores, bins=bins)[0].max() for scores, _, _ in panels)\n",
"\n",
"plt.figure(figsize=(15, 5))\n",
"for position, (scores, title, color) in enumerate(panels, start=1):\n",
"    plt.subplot(1, 3, position)\n",
"    plt.hist(scores, bins=bins, color=color, alpha=0.7)\n",
"    plt.title(title, fontsize=16)\n",
"    plt.xlabel(\"Absolute Change\", fontsize=14)\n",
"    plt.ylabel(\"Frequency\", fontsize=14)\n",
"    plt.xlim(min_score, max_score)\n",
"    plt.ylim(0, max_freq)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('conservation_laws_analysis_same_scales.png', dpi=300, bbox_inches='tight')\n",
"plt.show()"
], "metadata": { "id": "9FYy8-nIZwsy" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"def calculate_trajectory_entropy(trajectory):\n",
"    \"\"\"Shannon entropy (nats) of a 20-bin histogram of all values in the trajectory.\n",
"\n",
"    np.histogram flattens multi-dimensional input; scipy.stats.entropy\n",
"    normalises the bin densities to probabilities before summing.\n",
"    \"\"\"\n",
"    hist, _ = np.histogram(trajectory, bins=20, density=True)\n",
"    # Use the `stats` module from the imports cell rather than a bare `entropy` name\n",
"    return stats.entropy(hist)\n",
"\n",
"def calculate_free_energy(trajectory, temperature=1.0):\n",
"    \"\"\"Free-energy analogue F = mean(E) - temperature * S for one trajectory.\n",
"\n",
"    E is the squared distance of each point from the origin and S is the\n",
"    histogram entropy of those energies.\n",
"    \"\"\"\n",
"    energy = np.sum(trajectory**2, axis=1)\n",
"    energy_entropy = calculate_trajectory_entropy(energy)\n",
"    return np.mean(energy) - temperature * energy_entropy\n",
"\n",
"# Apply to all trajectories\n",
"trajectory_entropies = [calculate_trajectory_entropy(traj) for traj in trajectories_2d]\n",
"free_energies = [calculate_free_energy(traj) for traj in trajectories_2d]\n",
"\n",
"# Analyze the results\n",
"print(\"Mean trajectory entropy:\", np.mean(trajectory_entropies))\n",
"print(\"Mean free energy:\", np.mean(free_energies))\n",
"\n",
"# Visualize the results\n",
"plt.figure(figsize=(12, 5))\n",
"plt.subplot(121)\n",
"plt.hist(trajectory_entropies, bins=20)\n",
"plt.title(\"Distribution of Trajectory Entropies\", fontsize=16)\n",
"plt.xlabel(\"Entropy\", fontsize=14)\n",
"plt.ylabel(\"Frequency\", fontsize=14)\n",
"\n",
"plt.subplot(122)\n",
"plt.hist(free_energies, bins=20)\n",
"plt.title(\"Distribution of Free Energies\", fontsize=16)\n",
"plt.xlabel(\"Free Energy\", fontsize=14)\n",
"plt.ylabel(\"Frequency\", fontsize=14)\n",
"plt.tight_layout()\n",
"plt.show()"
], "metadata": { "id": "Ws8Ugh7kbj9T" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"def measure_computation_time(trajectories, num_samples, step=100):\n",
"    \"\"\"Measure computation time for different numbers of trajectories.\n",
"\n",
"    Times analyze_trajectory over the first `size` trajectories for\n",
"    size = step, 2*step, ... strictly below num_samples.\n",
"    Returns (sample_sizes, times) with times in seconds.\n",
"    \"\"\"\n",
"    sample_sizes = range(step, num_samples, step)\n",
"    times = []\n",
"    for size in sample_sizes:\n",
"        # perf_counter is monotonic and high-resolution, unlike time.time\n",
"        start_time = time.perf_counter()\n",
"        _ = [analyze_trajectory(traj) for traj in trajectories[:size]]\n",
"        times.append(time.perf_counter() - start_time)\n",
"    return sample_sizes, times\n",
"\n",
"def analyze_trajectory(trajectory):\n",
"    \"\"\"Placeholder for your trajectory analysis function.\"\"\"\n",
"    # Replace this with your actual analysis\n",
"    return calculate_hamiltonian(trajectory[:, 0], trajectory[:, 1])\n",
"\n",
"# Measure computation time\n",
"sample_sizes, computation_times = measure_computation_time(trajectories_2d, len(trajectories_2d))\n"
], "metadata": { "id": "c4hO5bUXb_VP" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Plot the results\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(sample_sizes, computation_times, 'b-')\n",
"plt.title(\"Computational Complexity\", fontsize=16)\n",
"plt.xlabel(\"Number of Trajectories\", fontsize=14)\n",
"plt.ylabel(\"Computation Time (seconds)\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
], "metadata": { "id": "OWw-V4apZX48" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"# Estimate complexity: fit time = a * n**b and report b as the empirical exponent\n",
"def complexity_function(x, a, b):\n",
"    \"\"\"Power-law model a * x**b.\"\"\"\n",
"    return a * x**b\n",
"\n",
"popt, _ = curve_fit(complexity_function, sample_sizes, computation_times)\n",
"\n",
"print(f\"Estimated complexity: O(n^{popt[1]:.2f})\")"
], "metadata": { "id": "Pady9Cj8ZIdz" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [
"def classify_trajectory(trajectory, threshold=0.5):\n",
"    \"\"\"Classify a trajectory as valid or invalid based on Hamiltonian conservation.\n",
"\n",
"    Returns True when |H(last) - H(first)| < threshold, with H from\n",
"    calculate_hamiltonian on (q, p) = (column 0, column 1).\n",
"    NOTE: this redefinition shadows the earlier action-angle classify_trajectory.\n",
"    \"\"\"\n",
"    hamiltonian_change = np.abs(calculate_hamiltonian(trajectory[0, 0], trajectory[0, 1]) -\n",
"                                calculate_hamiltonian(trajectory[-1, 0], trajectory[-1, 1]))\n",
"    return hamiltonian_change < threshold\n",
"\n",
"# Split the data (the classifier is rule-based, so only the test split is used)\n",
"X_train, X_test, y_train, y_test = model_selection.train_test_split(trajectories_2d, df['is_valid'], test_size=0.2, random_state=42)\n",
"\n",
"# Classify test set\n",
"y_pred = [classify_trajectory(traj) for traj in X_test]\n",
"\n",
"# Analyze errors\n",
"conf_matrix = metrics.confusion_matrix(y_test, y_pred)\n",
"class_report = metrics.classification_report(y_test, y_pred)\n",
"\n",
"print(\"Confusion Matrix:\")\n",
"print(conf_matrix)\n",
"print(\"\\nClassification Report:\")\n",
"print(class_report)\n",
"\n",
"# Analyze misclassified trajectories. An explicit boolean mask works whether\n",
"# X_test is a list or an ndarray (a list cannot be indexed by a boolean Series).\n",
"is_wrong = np.asarray(y_test) != np.asarray(y_pred)\n",
"misclassified = [traj for traj, wrong in zip(X_test, is_wrong) if wrong]\n",
"misclassified_labels = np.asarray(y_test)[is_wrong]\n",
"\n",
"print(\"\\nAnalysis of Misclassified Trajectories:\")\n",
"for i, (traj, true_label) in enumerate(zip(misclassified, misclassified_labels)):\n",
"    hamiltonian_change = np.abs(calculate_hamiltonian(traj[0, 0], traj[0, 1]) -\n",
"                                calculate_hamiltonian(traj[-1, 0], traj[-1, 1]))\n",
"    print(f\"Trajectory {i}:\")\n",
"    print(f\" True label: {'Valid' if true_label else 'Invalid'}\")\n",
"    print(f\" Predicted: {'Valid' if classify_trajectory(traj) else 'Invalid'}\")\n",
"    print(f\" Hamiltonian change: {hamiltonian_change:.4f}\")\n",
"    print(f\" Start point: {traj[0]}\")\n",
"    print(f\" End point: {traj[-1]}\")\n",
"    print()\n",
"\n",
"# Visualize up to three misclassified trajectories (fewer may exist)\n",
"n_show = min(3, len(misclassified))\n",
"if n_show:\n",
"    plt.figure(figsize=(15, 5))\n",
"    for i in range(n_show):\n",
"        plt.subplot(1, 3, i+1)\n",
"        plt.plot(misclassified[i][:, 0], misclassified[i][:, 1], 'r-')\n",
"        plt.scatter(misclassified[i][0, 0], misclassified[i][0, 1], c='g', label='Start')\n",
"        plt.scatter(misclassified[i][-1, 0], misclassified[i][-1, 1], c='b', label='End')\n",
"        plt.title(f\"Misclassified Trajectory {i+1}\", fontsize=16)\n",
"        plt.xlabel(\"PC1\", fontsize=14)\n",
"        plt.ylabel(\"PC2\", fontsize=14)\n",
"        plt.legend()\n",
"    plt.tight_layout()\n",
"    plt.show()"
], "metadata": { "id": "p9PhYaNpcJJd" }, "execution_count": null, "outputs": [] } ] }