sjepa

Usage

Instantiate the Base Model

import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download

# Download the pre-trained checkpoint (a state dict) from the Hugging Face Hub.
weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
# `weights_only=True` restricts unpickling to tensors and plain containers, which is
# safer for files downloaded from the internet; `map_location="cpu"` lets the
# checkpoint load on machines without a GPU.
model_state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

# Signal-related arguments
# raw: mne.io.BaseRaw -- your recording, loaded beforehand
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
)
# strict=False: the checkpoint intentionally lacks some parameters (see below).
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert not unexpected_keys, unexpected_keys
# The spatial positional encoder is not loaded from the checkpoint; it is
# initialized using `chs_info`, so its weight is the only missing key:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}, missing_keys

Instantiate the Downstream Architectures

Unlike the base model, the downstream architectures are equipped with a classification head that is not pre-trained. Guetschel et al. (2024, arXiv:2403.11772) introduce three downstream architectures:

  • a) Contextual downstream architecture
  • b) Post-local downstream architecture
  • c) Pre-local downstream architecture
import torch
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PreLocal,
    SignalJEPA_PostLocal,
)
from huggingface_hub import hf_hub_download

# Download the pre-trained checkpoint (a state dict) from the Hugging Face Hub.
weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
# `weights_only=True` restricts unpickling to tensors and plain containers, which is
# safer for files downloaded from the internet; `map_location="cpu"` lets the
# checkpoint load on machines without a GPU.
model_state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

# Signal-related arguments
# raw: mne.io.BaseRaw -- your recording, loaded beforehand
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

# The downstream architectures are equipped with an additional classification head
# which was not pre-trained. It has the following new parameters:
final_layer_keys = {
    "final_layer.spat_conv.weight",
    "final_layer.spat_conv.bias",
    "final_layer.linear.weight",
    "final_layer.linear.bias",
}


# a) Contextual downstream architecture
#    ----------------------------------
model = SignalJEPA_Contextual(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)
# strict=False: the new head (and the spatial encoder) are absent from the checkpoint.
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert not unexpected_keys, unexpected_keys
# The spatial positional encoder is not loaded from the checkpoint; it is
# initialized using `chs_info`:
assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}, missing_keys

# In the post-local (b) and pre-local (c) architectures, the transformer is discarded,
# so drop its weights and the positional encoder's weights from the checkpoint.
# `str.startswith` accepts a tuple of prefixes.
filtered_state_dict = {
    k: v for k, v in model_state_dict.items() if not k.startswith(("transformer.", "pos_encoder."))
}


# b) Post-local downstream architecture
#    ----------------------------------
model = SignalJEPA_PostLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
assert not unexpected_keys, unexpected_keys
assert set(missing_keys) == final_layer_keys, missing_keys


# c) Pre-local downstream architecture
#    ---------------------------------
model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
assert not unexpected_keys, unexpected_keys
# This architecture's new (not pre-trained) parameters use different key names:
assert set(missing_keys) == {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}, missing_keys
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support