Utilities


NeuralFieldManifold.utils.train_loop(model, train_loader, val_loader, n_epochs=100, lr=0.001, lambda_p=10.0, lambda_ar=1.0, lambda_energy=0.1, lambda_smooth=0.05, lambda_order=0.0, P0=0.5, W=40, p_max=6, device='cuda')

Train a model with the composite PINN-style loss.

Optimises the weighted sum of five loss terms (classification, AR reconstruction, energy constraint, smoothness, and order regularisation) using Adam with gradient clipping and ReduceLROnPlateau scheduling.
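Written out with the weights exposed as parameters below, the objective has the following form. The term symbols are shorthand introduced here for readability, not names taken from the implementation:

```latex
\mathcal{L} = \lambda_{p}\,\mathcal{L}_{\mathrm{CE}}
            + \lambda_{\mathrm{ar}}\,\mathcal{L}_{\mathrm{AR}}
            + \lambda_{\mathrm{energy}}\,\mathcal{L}_{\mathrm{energy}}
            + \lambda_{\mathrm{smooth}}\,\mathcal{L}_{\mathrm{smooth}}
            + \lambda_{\mathrm{order}}\,\mathcal{L}_{\mathrm{order}}
```

With the default lambda_order = 0.0, the order-regularisation term does not contribute.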

Parameters:
  • model (torch.nn.Module) – Model to train. Its forward pass must return a (coeffs, p_logits, p_hard, x_hat) tuple.

  • train_loader (DataLoader) – Training data yielding (x_batch, p_batch) tuples.

  • val_loader (DataLoader) – Validation data yielding (x_batch, p_batch) tuples.

  • n_epochs (int, optional) – Number of training epochs. Default is 100.

  • lr (float, optional) – Initial learning rate for Adam. Default is 1e-3.

  • lambda_p (float, optional) – Weight for the AR-order cross-entropy loss. Default is 10.0.

  • lambda_ar (float, optional) – Weight for the signal reconstruction MSE loss. Default is 1.0.

  • lambda_energy (float, optional) – Weight for the energy constraint loss. Default is 0.1.

  • lambda_smooth (float, optional) – Weight for the coefficient smoothness loss. Default is 0.05.

  • lambda_order (float, optional) – Weight for the order regulariser. Default is 0.0 (disabled).

  • P0 (float, optional) – Target power level for the energy loss. Default is 0.5.

  • W (int, optional) – Sliding-window size for the energy loss. Default is 40.

  • p_max (int, optional) – Maximum AR order (controls how many leading time steps are excluded in the AR loss). Default is 6.

  • device (str or torch.device, optional) – Compute device. Default is 'cuda'.

Returns:

history – Dictionary mapping metric names (e.g. 'train_loss', 'val_p_acc') to lists of per-epoch values.

Return type:

dict
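The exact form of each loss term is not documented here. The NumPy sketch below shows one plausible reading of the AR-reconstruction, energy and smoothness terms for a single series. The time-varying AR model, the choice to apply the energy penalty to the reconstruction, and the helper names (ar_reconstruct, ar_loss, energy_loss, smooth_loss) are all assumptions for illustration, not the library's code.

```python
import numpy as np


def ar_reconstruct(x, coeffs):
    """Assumed time-varying AR prediction: x_hat[t] = sum_k coeffs[t, k-1] * x[t-k].

    x has shape (T,), coeffs has shape (T, p_max). The first p_max steps lack a
    full history and are left at zero.
    """
    T, p_max = coeffs.shape
    x_hat = np.zeros(T)
    for t in range(p_max, T):
        past = x[t - p_max:t][::-1]  # x[t-1], x[t-2], ..., x[t-p_max]
        x_hat[t] = coeffs[t] @ past
    return x_hat


def ar_loss(x, x_hat, p_max=6):
    # Reconstruction MSE, excluding the first p_max time steps
    return float(np.mean((x[p_max:] - x_hat[p_max:]) ** 2))


def energy_loss(x_hat, P0=0.5, W=40):
    # Mean power in each length-W sliding window, penalised for deviating from P0
    window_power = np.convolve(x_hat ** 2, np.ones(W) / W, mode="valid")
    return float(np.mean((window_power - P0) ** 2))


def smooth_loss(coeffs):
    # Mean squared change in the coefficients between consecutive time steps
    return float(np.mean(np.diff(coeffs, axis=0) ** 2))


# Toy AR(2) series, with its coefficients zero-padded to p_max columns
rng = np.random.default_rng(0)
T, p_max = 200, 6
coeffs = np.zeros((T, p_max))
coeffs[:, 0], coeffs[:, 1] = 0.5, -0.3
noise = rng.normal(size=T)
x = np.zeros(T)
for t in range(2, T):
    x[t] = 0.5 * x[t - 1] - 0.3 * x[t - 2] + noise[t]

x_hat = ar_reconstruct(x, coeffs)
# Weighted sum using the train_loop default weights (classification term omitted)
total = 1.0 * ar_loss(x, x_hat, p_max) + 0.1 * energy_loss(x_hat) + 0.05 * smooth_loss(coeffs)
print(total)
```

Because the toy coefficients are constant over time, the smoothness penalty is exactly zero and the AR residual reduces to the innovation noise.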

NeuralFieldManifold.utils.bench_loop(model, X_val, coef_val, p_val, device, p_min=2, p_max=6, p_max_order=6, batch_size=32)

Evaluate a trained model on validation data and return benchmark metrics.

Runs the model in inference mode over the validation set and computes coefficient MSE, signal reconstruction MSE, per-order accuracy, and aggregate order-prediction statistics.

Parameters:
  • model (torch.nn.Module) – Trained model whose forward pass returns the standard (coeffs, p_logits, p_hard, x_hat) tuple.

  • X_val (array-like or torch.Tensor) – Validation time series of shape (N, T).

  • coef_val (np.ndarray) – Ground-truth AR coefficients of shape (N, T, max_ar_order).

  • p_val (array-like or torch.Tensor) – Ground-truth 0-indexed order class labels of shape (N,).

  • device (torch.device) – Device on which to run inference.

  • p_min (int, optional) – Minimum AR order corresponding to class index 0. Default is 2.

  • p_max (int, optional) – Maximum AR order. Default is 6.

  • p_max_order (int, optional) – Number of leading time steps to skip when computing signal MSE. Default is 6.

  • batch_size (int, optional) – Inference batch size. Default is 32.

Returns:

results – Dictionary with keys 'coeff_mse', 'signal_mse', 'p_mae', 'p_mape', per-order accuracies ('p2_acc', …, 'p6_acc'), and overall 'p_acc'.

Return type:

dict
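How the order statistics are computed is not spelled out above. A minimal sketch, assuming predicted orders are recovered as class index + p_min, 'p_mape' is expressed in percent, and each 'p{k}_acc' is the accuracy over series whose true order is k. The helpers order_metrics and signal_mse are hypothetical, not part of NeuralFieldManifold.

```python
import numpy as np


def order_metrics(p_pred_idx, p_true_idx, p_min=2, p_max=6):
    """Hypothetical order statistics computed from 0-indexed class labels."""
    p_pred = np.asarray(p_pred_idx) + p_min  # class index 0 -> AR order p_min
    p_true = np.asarray(p_true_idx) + p_min
    err = np.abs(p_pred - p_true)
    results = {
        "p_acc": float(np.mean(p_pred == p_true)),
        "p_mae": float(np.mean(err)),
        "p_mape": float(np.mean(err / p_true) * 100.0),  # assumed to be a percentage
    }
    for p in range(p_min, p_max + 1):
        mask = p_true == p
        # Accuracy over series whose true order is p; NaN if the order is absent
        results[f"p{p}_acc"] = float(np.mean(p_pred[mask] == p)) if mask.any() else float("nan")
    return results


def signal_mse(x_true, x_hat, p_max_order=6):
    # Skip the first p_max_order steps of every series (shape (N, T))
    return float(np.mean((x_true[:, p_max_order:] - x_hat[:, p_max_order:]) ** 2))


res = order_metrics([0, 1, 4, 2], [0, 1, 3, 2])
print(res["p_acc"], res["p_mae"], res["p5_acc"])  # 0.75 0.25 0.0
```

Here the single error (true order 5 predicted as 6) lowers 'p_acc' and 'p5_acc', while 'p6_acc' is NaN because no series has true order 6.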

NeuralFieldManifold.utils.bicoherence(x: ndarray, fs: float, nperseg: int = 256, noverlap: int | None = None, fmax: float | None = None) → tuple[ndarray, ndarray]

Compute bicoherence for detecting quadratic phase coupling.

Bicoherence values near 1 indicate strong phase-locking between frequency components f1, f2, and f1+f2.
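For reference, one widely used segment-averaged estimator of squared bicoherence is given below, where X_m(f) is the FFT of the m-th windowed segment and M is the number of segments. The normalisation this function actually uses is not stated here, so treat this as a standard definition rather than a description of the implementation:

```latex
b^{2}(f_1, f_2) =
  \frac{\left|\sum_{m=1}^{M} X_m(f_1)\,X_m(f_2)\,X_m^{*}(f_1 + f_2)\right|^{2}}
       {\sum_{m=1}^{M}\left|X_m(f_1)\,X_m(f_2)\right|^{2}\;\sum_{m=1}^{M}\left|X_m(f_1 + f_2)\right|^{2}}
```

By the Cauchy–Schwarz inequality this quantity lies in [0, 1]. It approaches 1 when the biphase φ_m(f1) + φ_m(f2) − φ_m(f1 + f2) stays constant across segments, and tends toward 0 when that phase is random.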

Parameters:
  • x (np.ndarray) – Input signal (1D).

  • fs (float) – Sampling frequency in Hz.

  • nperseg (int, optional) – Segment length, in samples, for each FFT. Default is 256.

  • noverlap (int, optional) – Number of samples shared by consecutive segments. Default is nperseg // 2.

  • fmax (float, optional) – Maximum frequency to compute, in Hz. Default is fs / 4.

Returns:

  • freqs (np.ndarray) – Frequency axis.

  • bic (np.ndarray) – Bicoherence matrix (f1 × f2).

Example

>>> freqs, bic_matrix = bicoherence(signal, fs=500, nperseg=512, fmax=50)
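To illustrate what the function measures, here is a self-contained NumPy sketch of the usual segment-averaged estimator (Hann window, mean removal per segment). It is a stand-in named bicoherence_sketch, not the library's implementation; in particular, clipping fmax to fs / 4 is a choice made here so that f1 + f2 never exceeds the Nyquist frequency. The synthetic test signals are likewise assumptions for demonstration.

```python
import numpy as np


def bicoherence_sketch(x, fs, nperseg=256, noverlap=None, fmax=None):
    """Segment-averaged squared bicoherence (illustrative, not the library code)."""
    x = np.asarray(x, dtype=float)
    noverlap = nperseg // 2 if noverlap is None else noverlap
    fmax = fs / 4 if fmax is None else min(fmax, fs / 4)  # keep f1 + f2 <= Nyquist
    starts = range(0, len(x) - nperseg + 1, nperseg - noverlap)
    if len(starts) == 0:
        raise ValueError("signal is shorter than nperseg")
    # One row per segment: mean-removed, Hann-windowed, then FFT
    segs = np.array([x[s:s + nperseg] for s in starts])
    segs = (segs - segs.mean(axis=1, keepdims=True)) * np.hanning(nperseg)
    X = np.fft.rfft(segs, axis=1)
    freqs = np.fft.rfftfreq(nperseg, d=1.0 / fs)
    k = np.flatnonzero(freqs <= fmax)
    k1, k2 = np.meshgrid(k, k, indexing="ij")
    X1, X2, X12 = X[:, k1], X[:, k2], X[:, k1 + k2]  # each (n_segments, nf, nf)
    num = np.abs(np.sum(X1 * X2 * np.conj(X12), axis=0)) ** 2
    den = np.sum(np.abs(X1 * X2) ** 2, axis=0) * np.sum(np.abs(X12) ** 2, axis=0)
    bic = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
    return freqs[k], bic


rng = np.random.default_rng(1)
fs, n_seg = 500, 64
t = np.arange(fs) / fs  # one 1-second block per segment (1 Hz resolution)


def make_signal(coupled):
    blocks = []
    for _ in range(n_seg):
        p1, p2 = rng.uniform(0, 2 * np.pi, size=2)
        # Coupled: the 25 Hz phase is locked to p1 + p2; otherwise it is independent
        p3 = p1 + p2 if coupled else rng.uniform(0, 2 * np.pi)
        blocks.append(np.cos(2 * np.pi * 10 * t + p1) + np.cos(2 * np.pi * 15 * t + p2)
                      + np.cos(2 * np.pi * 25 * t + p3) + 0.1 * rng.normal(size=t.size))
    return np.concatenate(blocks)


freqs, bic_c = bicoherence_sketch(make_signal(True), fs, nperseg=fs, noverlap=0, fmax=50)
_, bic_u = bicoherence_sketch(make_signal(False), fs, nperseg=fs, noverlap=0, fmax=50)
i, j = np.argmin(np.abs(freqs - 10)), np.argmin(np.abs(freqs - 15))
print(bic_c[i, j], bic_u[i, j])  # coupled value near 1, uncoupled value near 0
```

Each 1-second block gets fresh random phases, so only the coupled signal keeps a consistent biphase at (10 Hz, 15 Hz) across segments. Aligning nperseg with the block length and setting noverlap=0 keeps each segment inside one block.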