Skip to content

Commit

Permalink
fix tests; add sigma 0.05 runs; add analysis notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Nov 6, 2024
1 parent 086432a commit ed9064e
Show file tree
Hide file tree
Showing 4 changed files with 404 additions and 27 deletions.
292 changes: 265 additions & 27 deletions filter_reports/analyze_PF_report.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,213 @@
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Load the CSV file into a DataFrame\n",
"file_path = \"report_nov6_fn_run_PF_single_theta_dual_radio_NN.csv\"\n",
"data = pd.read_csv(file_path)\n",
"\n",
"# Plot the heatmap\n",
"for theta_err in data[\"theta_err\"].unique():\n",
" fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n",
" for movement, ax in [(\"bounce\", axs[0]), (\"circle\", axs[1])]:\n",
" df = data[\n",
" (data[\"movement\"] == movement) & (data[\"theta_err\"] == theta_err)\n",
" ].copy()\n",
"\n",
" # Convert 'N' to integer for sorting\n",
" df[\"N\"] = df[\"N\"].astype(int)\n",
"\n",
" # Sort by 'N' as integers\n",
" df = df.sort_values(by=\"N\")\n",
"\n",
" # Convert 'N' back to string and set as a categorical type with ordered categories\n",
" df[\"N\"] = df[\"N\"].astype(str)\n",
" df[\"N\"] = pd.Categorical(\n",
" df[\"N\"], categories=sorted(df[\"N\"].unique(), key=int), ordered=True\n",
" )\n",
"\n",
" df[\"theta_dot_err\"] = df[\"theta_dot_err\"].astype(str)\n",
"\n",
" # Average other fields over 'mse_craft_theta' with observed=True to avoid the warning\n",
" heatmap_data = (\n",
" df.groupby([\"N\", \"theta_dot_err\"], observed=True)\n",
" .agg({\"mse_craft_theta\": \"mean\"})\n",
" .reset_index()\n",
" )\n",
"\n",
" # Pivot the data for the heatmap\n",
" heatmap_pivot = heatmap_data.pivot(\n",
" index=\"theta_dot_err\", columns=\"N\", values=\"mse_craft_theta\"\n",
" )\n",
"\n",
" sns.heatmap(\n",
" heatmap_pivot,\n",
" annot=True,\n",
" fmt=\".2f\",\n",
" cmap=\"YlGnBu\",\n",
" cbar_kws={\"label\": \"Mean mse_craft_theta\"},\n",
" ax=ax,\n",
" )\n",
" ax.set_title(f\"{movement} theta_err:{theta_err}\")\n",
" ax.set_xlabel(\"N\")\n",
" ax.set_ylabel(\"theta_dot_err\")\n",
" fig.set_tight_layout(True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"file_path = \"report_nov6_fn_run_EKF_single_theta_dual_radio.csv\"\n",
"data = pd.read_csv(file_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Load the CSV file into a DataFrame\n",
"file_path = \"report_nov6_fn_run_EKF_single_theta_dual_radio.csv\"\n",
"data = pd.read_csv(file_path)\n",
"\n",
"# Plot the heatmap\n",
"for p in data[\"p\"].unique():\n",
" fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n",
" for movement, ax in [(\"bounce\", axs[0]), (\"circle\", axs[1])]:\n",
" df = data[\n",
" (data[\"movement\"] == movement)\n",
" & (data[\"p\"] == p)\n",
" & (data[\"dynamic_R\"] == 0.0)\n",
" ].copy()\n",
"\n",
" # Convert 'N' to integer for sorting\n",
" df[\"noise_std\"] = df[\"noise_std\"].astype(float)\n",
" df[\"phi_std\"] = df[\"phi_std\"].astype(float)\n",
"\n",
" # Sort by 'N' as integers\n",
" df = df.sort_values(by=\"noise_std\")\n",
" df = df.sort_values(by=\"phi_std\")\n",
"\n",
" # Convert 'N' back to string and set as a categorical type with ordered categories\n",
" df[\"noise_std\"] = df[\"noise_std\"].astype(str)\n",
" df[\"noise_std\"] = pd.Categorical(\n",
" df[\"noise_std\"],\n",
" categories=sorted(df[\"noise_std\"].unique(), key=float),\n",
" ordered=True,\n",
" )\n",
"\n",
" df[\"phi_std\"] = pd.Categorical(\n",
" df[\"phi_std\"],\n",
" categories=sorted(df[\"phi_std\"].unique(), key=float),\n",
" ordered=True,\n",
" )\n",
"\n",
" # Average other fields over 'mse_craft_theta' with observed=True to avoid the warning\n",
" heatmap_data = (\n",
" df.groupby([\"noise_std\", \"phi_std\"], observed=True)\n",
" .agg({\"mse_craft_theta\": \"mean\"})\n",
" .reset_index()\n",
" )\n",
"\n",
" # Pivot the data for the heatmap\n",
" heatmap_pivot = heatmap_data.pivot(\n",
" index=\"phi_std\", columns=\"noise_std\", values=\"mse_craft_theta\"\n",
" )\n",
"\n",
" sns.heatmap(\n",
" heatmap_pivot,\n",
" annot=True,\n",
" fmt=\".2f\",\n",
" cmap=\"YlGnBu\",\n",
" cbar_kws={\"label\": \"Mean mse_craft_theta\"},\n",
" ax=ax,\n",
" )\n",
" ax.set_title(f\"{movement} p:{p}\")\n",
" ax.set_xlabel(\"noise_std\")\n",
" ax.set_ylabel(\"phi_std\")\n",
" fig.set_tight_layout(True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Load the CSV file into a DataFrame\n",
"file_path = \"report_nov4_fn_run_PF_single_theta_dual_radio_NN.csv\"\n",
"file_path = \"report_nov6_fn_run_PF_single_theta_single_radio_NN.csv\"\n",
"data = pd.read_csv(file_path)\n",
"\n",
"# Separate the data by movement type\n",
"for movement in [\"bounce\", \"circle\"]:\n",
" movement_data = data[data[\"movement\"] == movement]\n",
"\n",
" # Scatter plot for 'bounce' movement\n",
" plt.figure(figsize=(10, 6))\n",
" for n in movement_data[\"N\"].unique():\n",
" subset = movement_data[movement_data[\"N\"] == n]\n",
" plt.scatter(\n",
" subset[\"theta_dot_err\"],\n",
" subset[\"mse_craft_theta\"],\n",
" label=f\"N={n}\",\n",
" alpha=0.7,\n",
" )\n",
"\n",
" plt.xscale(\"log\")\n",
" plt.xlim(\n",
" movement_data[\"theta_dot_err\"].min() / 2,\n",
" movement_data[\"theta_dot_err\"].max() * 2,\n",
" )\n",
" plt.xlabel(\"theta_dot_err (log scale)\")\n",
" plt.ylabel(\"mse_craft_theta\")\n",
" plt.title(\"Bounce Movement: theta_dot_err vs mse_craft_theta (Scatter Plot)\")\n",
" plt.legend()\n",
" plt.show()"
"# Plot the heatmap\n",
"for theta_err in data[\"theta_err\"].unique():\n",
" fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n",
" for movement, ax in [(\"bounce\", axs[0]), (\"circle\", axs[1])]:\n",
" df = data[\n",
" (data[\"movement\"] == movement) & (data[\"theta_err\"] == theta_err)\n",
" ].copy()\n",
"\n",
" # Convert 'N' to integer for sorting\n",
" df[\"N\"] = df[\"N\"].astype(int)\n",
"\n",
" # Sort by 'N' as integers\n",
" df = df.sort_values(by=\"N\")\n",
"\n",
" # Convert 'N' back to string and set as a categorical type with ordered categories\n",
" df[\"N\"] = df[\"N\"].astype(str)\n",
" df[\"N\"] = pd.Categorical(\n",
" df[\"N\"], categories=sorted(df[\"N\"].unique(), key=int), ordered=True\n",
" )\n",
"\n",
" df[\"theta_dot_err\"] = df[\"theta_dot_err\"].astype(str)\n",
"\n",
" # Average other fields over 'mse_craft_theta' with observed=True to avoid the warning\n",
" heatmap_data = (\n",
" df.groupby([\"N\", \"theta_dot_err\"], observed=True)\n",
" .agg({\"mse_single_radio_theta\": \"mean\"})\n",
" .reset_index()\n",
" )\n",
"\n",
" # Pivot the data for the heatmap\n",
" heatmap_pivot = heatmap_data.pivot(\n",
" index=\"theta_dot_err\", columns=\"N\", values=\"mse_single_radio_theta\"\n",
" )\n",
"\n",
" sns.heatmap(\n",
" heatmap_pivot,\n",
" annot=True,\n",
" fmt=\".2f\",\n",
" cmap=\"YlGnBu\",\n",
" cbar_kws={\"label\": \"Mean mse_single_radio_theta\"},\n",
" ax=ax,\n",
" )\n",
" ax.set_title(f\"{movement} theta_err:{theta_err}\")\n",
" ax.set_xlabel(\"N\")\n",
" ax.set_ylabel(\"theta_dot_err\")\n",
" fig.set_tight_layout(True)"
]
},
{
Expand All @@ -46,7 +221,70 @@
"metadata": {},
"outputs": [],
"source": [
"bounce_data[\"theta_dot_err\"].min()"
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# Load the CSV file into a DataFrame\n",
"file_path = \"report_nov6_fn_run_EKF_single_theta_single_radio.csv\"\n",
"data = pd.read_csv(file_path)\n",
"\n",
"# Plot the heatmap\n",
"for p in data[\"p\"].unique():\n",
" fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n",
" for movement, ax in [(\"bounce\", axs[0]), (\"circle\", axs[1])]:\n",
" df = data[\n",
" (data[\"movement\"] == movement)\n",
" & (data[\"p\"] == p)\n",
" & (data[\"dynamic_R\"] == 0.0)\n",
" ].copy()\n",
"\n",
" # Convert 'N' to integer for sorting\n",
" df[\"noise_std\"] = df[\"noise_std\"].astype(float)\n",
" df[\"phi_std\"] = df[\"phi_std\"].astype(float)\n",
"\n",
" # Sort by 'N' as integers\n",
" df = df.sort_values(by=\"noise_std\")\n",
" df = df.sort_values(by=\"phi_std\")\n",
"\n",
" # Convert 'N' back to string and set as a categorical type with ordered categories\n",
" df[\"noise_std\"] = df[\"noise_std\"].astype(str)\n",
" df[\"noise_std\"] = pd.Categorical(\n",
" df[\"noise_std\"],\n",
" categories=sorted(df[\"noise_std\"].unique(), key=float),\n",
" ordered=True,\n",
" )\n",
"\n",
" df[\"phi_std\"] = pd.Categorical(\n",
" df[\"phi_std\"],\n",
" categories=sorted(df[\"phi_std\"].unique(), key=float),\n",
" ordered=True,\n",
" )\n",
"\n",
" # Average other fields over 'mse_craft_theta' with observed=True to avoid the warning\n",
" heatmap_data = (\n",
" df.groupby([\"noise_std\", \"phi_std\"], observed=True)\n",
" .agg({\"mse_single_radio_theta\": \"mean\"})\n",
" .reset_index()\n",
" )\n",
"\n",
" # Pivot the data for the heatmap\n",
" heatmap_pivot = heatmap_data.pivot(\n",
" index=\"phi_std\", columns=\"noise_std\", values=\"mse_single_radio_theta\"\n",
" )\n",
"\n",
" sns.heatmap(\n",
" heatmap_pivot,\n",
" annot=True,\n",
" fmt=\".2f\",\n",
" cmap=\"YlGnBu\",\n",
" cbar_kws={\"label\": \"Mean mse_single_radio_theta\"},\n",
" ax=ax,\n",
" )\n",
" ax.set_title(f\"{movement} p:{p}\")\n",
" ax.set_xlabel(\"noise_std\")\n",
" ax.set_ylabel(\"phi_std\")\n",
" fig.set_tight_layout(True)"
]
},
{
Expand Down
74 changes: 74 additions & 0 deletions latest_configs/paired_sigma0p05.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
datasets:
batch_size: 256
empirical_data_fn: /home/mouse9911/gits/spf/empirical_dists/full.pkl
empirical_individual_radio: false
empirical_symmetry: true
flip: false
double_flip: True
precompute_cache: /home/mouse9911/precompute_cache_chunk16_sept
scatter: continuous
scatter_k: 21
shuffle: true
sigma: 0.05
skip_qc: true
snapshots_adjacent_stride: 1
train_snapshots_per_session: 1
val_snapshots_per_session: 1
random_snapshot_size: False
snapshots_stride: 1
train_paths:
- /mnt/4tb_ssd/nosig_data/train.txt
val_paths:
- /mnt/4tb_ssd/nosig_data/val.txt
val_holdout_fraction: 0.2
val_subsample_fraction: 0.2
workers: 20
global:
beamformer_input: true
empirical_input: true
n_radios: 2
nthetas: 65
phase_input: true
rx_spacing_input: true
seed: 10
logger:
log_every: 100
name: wandb
plot_every: 15000
project: 2024_nov2_single_paired_multi
model:
block: true
bn: true
depth: 4
detach: true
dropout: 0.0
hidden: 1024
load_single: true
name: pairedbeamformer
norm: layer
single:
block: true
bn: true
depth: 4
detach: true
dropout: 0.0
hidden: 1024
input_dropout: 0.3
norm: layer
optim:
amp: true
checkpoint: /home/mouse9911/gits/spf/nov4_checkpoints/single_checkpoints_inputdo0p3_sigma0p05/best.pth
checkpoint_every: 5000
device: cuda
direct_loss: false
dtype: torch.float32
epochs: 60
head_start: 0
learning_rate: 0.0002
loss: mse
output: /home/mouse9911/gits/spf/nov4_checkpoints/paired_checkpoints_inputdo0p3_sigma0p05
resume_step: 0
save_on: val_paired_loss
scheduler_step: 6
val_every: 10000
weight_decay: 0.0
Loading

0 comments on commit ed9064e

Please sign in to comment.