Manifold reconstruction of mice hippocampus

Manifold reconstruction of mice hippocampus#

import numpy as np
import mne
import os
from scipy import signal
from scipy.signal import hilbert, butter, filtfilt, spectrogram
from scipy.ndimage import gaussian_filter, gaussian_filter1d
from matplotlib.ticker import ScalarFormatter
from statsmodels.tsa.stattools import acf
import matplotlib.pyplot as plt

from NeuralFieldManifold.models import AR
from NeuralFieldManifold.embedders import embed
from NeuralFieldManifold.pub_utils import *

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

set_pub_style()
def bandpass_filter(x, Fs, low, high, order=4):
    wn = [low/(Fs/2), high/(Fs/2)]
    sos = signal.butter(order, wn, btype='bandpass', output='sos')
    return signal.sosfiltfilt(sos, x)

def envelope_normalize(x, fs, fband=(15, 30), env_lp_hz=5.0, eps=1e-6):
    b, a = signal.butter(4, np.array(fband)/(fs/2), btype='bandpass')
    x_bp = signal.filtfilt(b, a, x)
    z = hilbert(x_bp)
    A = np.abs(z)
    if env_lp_hz is not None:
        b_lp, a_lp = signal.butter(2, env_lp_hz/(fs/2), btype='low')
        A = signal.filtfilt(b_lp, a_lp, A)
    x_flat = x_bp / (A + eps)
    x_flat *= np.median(A)
    return x_flat, x_bp, A

def process_signal(LFP, chan, Fs, time_window, 
                   notch_freq=50, notch_Q=30,
                   bandpass_low=80, bandpass_high=140,
                   fband=(1, 80), env_lp_hz=3.0,
                   target_Fs=500):                                                                                                          
    """Process LFP signal with notch filter, bandpass, envelope normalization, and downsampling."""
    filtered = signal.filtfilt(*signal.iirnotch(notch_freq, notch_Q, Fs), LFP[time_window, chan])
    x = np.asarray(filtered).ravel()
    xs_raw = bandpass_filter(x, Fs, low=bandpass_low, high=bandpass_high)
    
    # Downsample if needed (before envelope normalization for efficiency)
    if Fs > target_Fs:
        r = int(np.floor(Fs / target_Fs))
        xs_raw = signal.decimate(xs_raw, r, ftype='iir', zero_phase=True).astype(float)
        Fs_out = Fs / r
    else:
        Fs_out = Fs
    
    xs, _, _ = envelope_normalize(xs_raw, Fs_out, fband=fband, env_lp_hz=env_lp_hz)
    return xs, xs_raw, Fs_out

def fit_ar_model(xs, Fs, p=4, train_frac=0.9, refresh_every=5):
    X, y = AR.lag_matrix(xs, p)
    ntr = int(train_frac * len(y))
    X_tr, y_tr = X[:ntr], y[:ntr]
    w = AR.fit(X_tr, y_tr)
    N_te = len(y) - ntr
    t0 = p + ntr
    true_test = xs[t0 : t0 + N_te]
    yhat_ctrl = AR.hybrid_predict(xs, w, p, start_idx=t0, n_steps=N_te, refresh_every=refresh_every)
    full_pred = AR.hybrid_predict(xs, w, p, start_idx=p, n_steps=len(xs)-p, refresh_every=refresh_every)
    xs_pred_full = np.concatenate([xs[:p], full_pred])
    mse, mae, corr = AR.metrics(true_test, yhat_ctrl)
    return {
        'w': w, 'p': p, 'ntr': ntr, 'y': y, 'true_test': true_test,
        'yhat_ctrl': yhat_ctrl, 'xs_pred_full': xs_pred_full, 't0': t0,
        'mse': mse, 'mae': mae, 'corr': corr, 'refresh_every': refresh_every
    }
filename = '/Users/novak/Downloads/eke003_spont_250818_175844.npy'
channel_id = 1
raw_lfp = np.load(filename)[channel_id]
fs_original = 20000

print(f"Loaded LFP: {raw_lfp.shape} samples at {fs_original} Hz")
Loaded LFP: (10841760,) samples at 20000 Hz
snippet_sec = 5
time_window = np.arange(int(snippet_sec * fs_original))

# Reshape raw_lfp to (samples, channels) format for process_signal
LFP = raw_lfp.reshape(-1, 1)

xs, xs_raw, Fs = process_signal(LFP, chan=0, Fs=fs_original, time_window=time_window,
                                 bandpass_low=1.0, bandpass_high=50.0,
                                 fband=(1, 80), env_lp_hz=3.0,
                                 target_Fs=500)

time = np.arange(len(xs)) / Fs
print(f"Processed signal: {len(xs)} samples at {Fs} Hz")
Processed signal: 100000 samples at 20000 Hz
# Sweep AR orders for elbow plot
p_list = list(range(2, 7))
mse_te_list = []
for p in p_list:
    X, y_temp = AR.lag_matrix(xs, p)
    ntr_temp = int(0.1 * len(y_temp))
    X_tr, y_tr = X[:ntr_temp], y_temp[:ntr_temp]
    X_te, y_te = X[ntr_temp:], y_temp[ntr_temp:]
    w_temp = AR.fit(X_tr, y_tr)
    yhat_te = AR.predict_from_params(X_te, w_temp)
    mse_te, _, _ = AR.metrics(y_te, yhat_te)
    mse_te_list.append(mse_te)

best_p_by_mse = p_list[int(np.argmin(mse_te_list))]

def plot_mse_elbow(p_list, mse_list, best_p, color_main="black", color_alt="darkred"):
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(p_list, mse_list, marker='o', color=color_main, lw=1.5, ms=5)
    ax.axvline(best_p, linestyle='--', color=color_alt, linewidth=1)
    ax.text(best_p, ax.get_ylim()[1]*0.95, f"best p={best_p}", ha='center', va='top', color=color_alt, fontsize=11)
    sf = ScalarFormatter(useMathText=True)
    sf.set_powerlimits((-3, 3))
    ax.yaxis.set_major_formatter(sf)
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
    ax.yaxis.get_offset_text().set_size(10)
    prettify(ax, title="AR(p) elbow: Test MSE vs p", xlabel="Order p", ylabel="Test MSE")
    plt.tight_layout()
    plt.show()

best_p_by_mse = 3

plot_mse_elbow(p_list, mse_te_list, best_p_by_mse)
ar_results = fit_ar_model(xs, Fs, p=3, train_frac=0.1, refresh_every=10)

w = ar_results['w']
p_use = ar_results['p']
ntr = ar_results['ntr']
y = ar_results['y']
true_test = ar_results['true_test']
yhat_ctrl = ar_results['yhat_ctrl']
xs_pred_full = ar_results['xs_pred_full']
REFRESH_EVERY = ar_results['refresh_every']

print(f"AR({p_use}): MSE={ar_results['mse']:.6f}, corr={ar_results['corr']:.3f}")
def plot_ar_coefficients(w, p, color_main="black", color_alt="darkred"):
    fig, ax = plt.subplots(figsize=(2, 3))
    ax.bar(np.arange(1, p+1), w[1:], color=color_alt, edgecolor=color_main, alpha=0.8)
    prettify(ax, title=f"AR({p}) coefficients (OLS)", xlabel="Lag i", ylabel="aᵢ")
    plt.tight_layout()
    plt.show()

plot_ar_coefficients(w, p_use)
../_images/02aa6fcb40aa99b00a039960f6287aaf4ab28d108001868dbff9c1bc9fba2125.png
def plot_prediction(y_true, y_pred, p, refresh, color_main="black", color_alt="darkred"):
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.plot(y_true + 1, color=color_main, lw=2, label="Truth")
    ax.plot(y_pred, color=color_alt, lw=2, label=f"AR({p}) pred (refresh={refresh})")
    prettify(ax, title="AR(p) controlled 1→k-step prediction — test window",
             xlabel="Time (test index)", ylabel="x(t)", add_legend=True)
    plt.tight_layout()
    plt.show()

plot_prediction(y[ntr:], yhat_ctrl, p_use, REFRESH_EVERY)
../_images/d1414d7609cbc45562583e2645f4108baf4746651c2efc2ef151b08959377413.png
def plot_residual_acf(resid, nlags=60, color_main="black", color_alt="darkred"):
    acf_vals = acf(resid, nlags=nlags)
    bound = 1.96 / np.sqrt(max(len(resid), 1))
    fig, ax = plt.subplots(figsize=(4, 3))
    markerline, stemlines, baseline = ax.stem(range(len(acf_vals)), acf_vals, basefmt=" ")
    plt.setp(stemlines, color=color_main, linewidth=1.2)
    plt.setp(markerline, marker='o', markersize=5, markeredgecolor=color_main, markerfacecolor=color_alt)
    ax.axhline(bound, linestyle='--', color=color_alt, linewidth=1)
    ax.axhline(-bound, linestyle='--', color=color_alt, linewidth=1)
    plt.tight_layout()
    plt.show()

resid_te = y[ntr:] - yhat_ctrl
plot_residual_acf(resid_te)
../_images/8d640a5c01c719598017ee923067f2301632e507c027f8175e659802e0797c67.png
def plot_scatter(y_true, y_pred, refresh, color_main="black", color_alt="darkred"):
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.scatter(y_true, y_pred, c=color_alt, alpha=0.35, s=10, edgecolors="none")
    vmin, vmax = min(np.min(y_true), np.min(y_pred)), max(np.max(y_true), np.max(y_pred))
    ax.plot([vmin, vmax], [vmin, vmax], color=color_main, lw=1)
    yt = y_true - np.mean(y_true)
    yp = y_pred - np.mean(y_pred)
    corr = float((yt @ yp) / np.sqrt((yt @ yt) * (yp @ yp) + 1e-12))
    ss_res = np.sum((y_true - y_pred)**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)
    r2 = 1 - ss_res / (ss_tot + 1e-12)
    ax.text(0.05, 0.95, f"$R^2$ = {r2:.3f}\n$\\rho$ = {corr:.3f}",
            transform=ax.transAxes, ha='left', va='top', fontsize=11,
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, boxstyle="round"))
    ax.set_aspect('equal', adjustable='box')
    prettify(ax, title=f"Test scatter: y_pred vs y_true (refresh={refresh})", xlabel="y_true", ylabel="y_pred")
    plt.tight_layout()
    plt.show()

plot_scatter(y[ntr:], yhat_ctrl, REFRESH_EVERY)
../_images/f1c3bdcff77828150740eb7589832d0f26819e78da5009a7719ca08c33cc67f4.png
def plot_state_space(X3_true, X3_pred, p, tau, refresh, color_true="black", color_pred="darkred"):
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(X3_true[:,0], X3_true[:,1], X3_true[:,2], label="TRUE", color=color_true, alpha=0.85)
    ax.plot(X3_pred[:,0], X3_pred[:,1], X3_pred[:,2], label=f"AR({p}) hybrid (refresh={refresh})", color=color_pred, alpha=0.85)
    ax.set_title(f"State space — TRUE vs AR hybrid (τ={tau}, refresh_every={refresh})")
    ax.set_xlabel("x(t)")
    ax.set_ylabel(f"x(t-{tau})")
    ax.set_zlabel(f"x(t-{2*tau})")
    ax.legend(frameon=False)
    plt.tight_layout()
    plt.show()

tau = 10
hy = yhat_ctrl - np.mean(yhat_ctrl)
hy = hy * (np.std(true_test) / (np.std(hy) + 1e-12)) + np.mean(true_test)

X3_true = embed(true_test, 3, tau)
X3_hyb = embed(hy, 3, tau)

plot_state_space(X3_true, X3_hyb, p_use, tau, REFRESH_EVERY)
../_images/895ecf7989a2a985b3ebd8c7262cf5b4bd1d1ec111aa87063e083b6421ab1536.png
xs_pred_full
array([ 33.07407173,  30.0745926 ,  27.37802688, ..., -28.04769689,
       -30.74224474, -34.19860147], shape=(2500,))
def plot_state_space_windows_comparison(xs, xs_pred, Fs, tau, embed_dim=3, 
                                         windows=[(0, 3), (1, 2), (2, 3), (3, 4)]):
    fig, axes = plt.subplots(1, 4, figsize=(len(windows)*4, 4), subplot_kw={'projection': '3d'})
    
    for ax, (t_start, t_end) in zip(axes, windows):
        idx_start = int(t_start * Fs)
        idx_end = int(t_end * Fs)
        
        # True signal
        segment_true = xs[idx_start:idx_end]
        X3_true = embed(segment_true, embed_dim, tau)
        
        # Predicted signal (variance-matched)
        segment_pred = xs_pred[idx_start:idx_end]
        seg_pred_matched = segment_pred - np.mean(segment_pred)
        seg_pred_matched = seg_pred_matched * (np.std(segment_true) / (np.std(seg_pred_matched) + 1e-12))
        seg_pred_matched = seg_pred_matched + np.mean(segment_true)
        X3_pred = embed(seg_pred_matched, embed_dim, tau)
        
        ax.plot(X3_true[:, 0], X3_true[:, 1], X3_true[:, 2], color='black', alpha=0.85, lw=0.8, label='True')
        ax.plot(X3_pred[:, 0], X3_pred[:, 1], X3_pred[:, 2], color='darkred', alpha=0.85, lw=0.8, label='AR(4) pred')
        ax.set_xlabel("x(t)")
        ax.set_ylabel(f"x(t-{tau})")
        ax.set_zlabel(f"x(t-{2*tau})")
        ax.text2D(0.5, -0.15, f"{t_start}-{t_end} s", transform=ax.transAxes, ha='center', fontsize=11)
    
    axes[0].legend(frameon=False, loc='upper left')
    # fig.suptitle(f"State space: TRUE vs AR({p_use}) (τ={tau}, refresh={REFRESH_EVERY})", y=0.02)
    plt.tight_layout()
    plt.show()

# Build full prediction for visualization
full_pred = AR.hybrid_predict(xs, w, p_use, start_idx=p_use, n_steps=len(xs)-p_use, refresh_every=REFRESH_EVERY)
xs_pred_full = np.concatenate([xs[:p_use], full_pred])

plot_state_space_windows_comparison(xs, xs_pred_full, Fs, tau)
../_images/dd2705d365c72348d2085404f5d434b5a6557d355ad2b40aa876074ae9dd2366.png
Fs
500.0
def plot_state_space_tau_sweep(xs, xs_pred, Fs, t_start, t_end, taus=[3, 30, 60, 90], embed_dim=3):
    fig, axes = plt.subplots(1, 4, figsize=(20, 5), subplot_kw={'projection': '3d'})
    
    idx_start = int(t_start * Fs)
    idx_end = int(t_end * Fs)
    segment_true = xs[idx_start:idx_end]
    segment_pred = xs_pred[idx_start:idx_end]
    
    # Variance-match prediction
    seg_pred_matched = segment_pred - np.mean(segment_pred)
    seg_pred_matched = seg_pred_matched * (np.std(segment_true) / (np.std(seg_pred_matched) + 1e-12))
    seg_pred_matched = seg_pred_matched + np.mean(segment_true)
    
    for ax, tau_val in zip(axes, taus):
        X3_true = embed(segment_true, embed_dim, tau_val)
        X3_pred = embed(seg_pred_matched, embed_dim, tau_val)
        
        ax.plot(X3_true[:, 0], X3_true[:, 1], X3_true[:, 2], color='black', alpha=0.85, lw=0.8, label='True')
        ax.plot(X3_pred[:, 0], X3_pred[:, 1], X3_pred[:, 2], color='darkred', alpha=0.85, lw=0.8, label=f'AR({p_use}) pred')
        ax.set_xlabel("x(t)", labelpad=10)
        ax.set_ylabel(f"x(t-{tau_val})", labelpad=10)
        ax.set_zlabel(f"x(t-{2*tau_val})", labelpad=10)
        ax.text2D(0.5, -0.1, f"τ = {tau_val}", transform=ax.transAxes, ha='center', fontsize=11)
    
    axes[0].legend(frameon=False, loc='upper left')
    fig.subplots_adjust(left=0.02, right=0.98, bottom=0.1, top=0.95, wspace=0.15)
    plt.show()

# Use valid time window within 5 second signal
plot_state_space_tau_sweep(xs, xs_pred_full, Fs, t_start=2, t_end=4)
../_images/89ee2aabd07c4d939e22ad9b8d54b27a3628da08ef8901b4efda4b0c74969769.png
def sweep_channels_plot(raw_data, channels, time_windows, taus, fs_orig,
                        p=3, train_frac=0.9, refresh_every=5,
                        bandpass_low=1.0, bandpass_high=50.0, 
                        fband=(1, 80), env_lp_hz=3.0, embed_dim=3,
                        target_Fs=500):
    """
    Process multiple channels from mouse data and plot state space comparison.
    
    Parameters
    ----------
    raw_data : ndarray
        Raw LFP data with shape (channels, samples)
    channels : list of int
        Channel indices to process
    time_windows : list of tuple
        (start_sec, end_sec) for each channel's plot window
    taus : list of int
        Embedding delay for each channel
    fs_orig : float
        Original sampling frequency
    target_Fs : float
        Target sampling frequency after downsampling
    """
    n = len(channels)
    fig, axes = plt.subplots(1, n, figsize=(n*5, 5), subplot_kw={'projection': '3d'})
    if n == 1:
        axes = [axes]
    
    # Convert to (samples, channels) format
    LFP = raw_data.T
    
    for ax, chan, (t_start, t_end), tau_val in zip(axes, channels, time_windows, taus):
        # Create time window indices
        time_window = np.arange(int(t_end * fs_orig))
        
        # Process signal (now returns Fs_out)
        xs_ch, _, Fs_ch = process_signal(LFP, chan, fs_orig, time_window,
                                          bandpass_low=bandpass_low, bandpass_high=bandpass_high,
                                          fband=fband, env_lp_hz=env_lp_hz,
                                          target_Fs=target_Fs)
        
        # Fit AR model
        ar_res = fit_ar_model(xs_ch, Fs_ch, p=p, train_frac=train_frac, refresh_every=refresh_every)
        w_ch, xs_pred_ch = ar_res['w'], ar_res['xs_pred_full']
        
        # Extract window (using downsampled Fs)
        idx_start, idx_end = int(t_start * Fs_ch), int(t_end * Fs_ch)
        segment_true = xs_ch[idx_start:idx_end]
        segment_pred = xs_pred_ch[idx_start:idx_end]
        
        # Variance-match prediction
        seg_pred_matched = segment_pred - np.mean(segment_pred)
        seg_pred_matched = seg_pred_matched * (np.std(segment_true) / (np.std(seg_pred_matched) + 1e-12))
        seg_pred_matched = seg_pred_matched + np.mean(segment_true)
        
        # Embed
        X3_true = embed(segment_true, embed_dim, tau_val)
        X3_pred = embed(seg_pred_matched, embed_dim, tau_val)
        
        # Plot
        ax.plot(X3_true[:, 0], X3_true[:, 1], X3_true[:, 2], color='black', alpha=0.85, lw=0.8, label='True')
        ax.plot(X3_pred[:, 0], X3_pred[:, 1], X3_pred[:, 2], color='darkred', alpha=0.85, lw=0.8, label=f'AR({p}) pred')
        ax.set_xlabel("x(t)", labelpad=10)
        ax.set_ylabel(f"x(t-{tau_val})", labelpad=10)
        ax.set_zlabel(f"x(t-{2*tau_val})", labelpad=10)
        ax.text2D(0.5, -0.1, f"Channel {chan}", transform=ax.transAxes, ha='center', fontsize=11)
    
    axes[0].legend(frameon=False, loc='upper left')
    fig.subplots_adjust(left=0.02, right=0.98, bottom=0.1, top=0.95, wspace=0.15)
    plt.show()

# Load full raw data for multi-channel sweep
raw_data_full = np.load(filename)
print(f"Full data shape: {raw_data_full.shape}")

# Sweep across channels (using valid channels and time windows)
sweep_channels_plot(
    raw_data_full, 
    channels=[0, 1, 2, 3],
    time_windows=[(1, 3), (1, 3), (2, 4), (1, 3)],
    taus=[30, 30, 30, 30],
    fs_orig=fs_original,
    target_Fs=500
)
Full data shape: (4, 10841760)
../_images/9479fcba505ebc7b8eccb5732c32028a317fc9dfad089faa443ee75219514afc.png