Usage
Instantiate the Base Model
import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download
weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
model_state_dict = torch.load(weights_path)
# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]
model = SignalJEPA(
sfreq=sfreq,
input_window_seconds=2,
chs_info=chs_info,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized from `chs_info`, so its weights are not in the checkpoint:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
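As a quick sanity check, the loaded encoder can be run on a random batch. This is an illustrative sketch only: the dummy tensor below stands in for real EEG windows, and the shape of the returned features depends on the model configuration.
# Illustrative only: dummy batch of shape (batch_size, n_channels, n_times)
x = torch.randn(8, len(chs_info), int(2 * sfreq))
with torch.no_grad():
    features = model(x)
print(features.shape)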
Instantiate the Downstream Architectures
Unlike the base model, the downstream architectures are equipped with a classification head that is not pre-trained. Guetschel et al. (2024, arXiv:2403.11772) introduce three downstream architectures:
- a) Contextual downstream architecture
- b) Post-local downstream architecture
- c) Pre-local downstream architecture
import torch
from braindecode.models import (
SignalJEPA_Contextual,
SignalJEPA_PreLocal,
SignalJEPA_PostLocal,
)
from huggingface_hub import hf_hub_download
weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth")
model_state_dict = torch.load(weights_path)
# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]
# The downstream architectures are equipped with an additional classification head
# which was not pre-trained. It has the following new parameters:
final_layer_keys = {
"final_layer.spat_conv.weight",
"final_layer.spat_conv.bias",
"final_layer.linear.weight",
"final_layer.linear.bias",
}
# a) Contextual downstream architecture
# ----------------------------------
model = SignalJEPA_Contextual(
sfreq=sfreq,
input_window_seconds=2,
chs_info=chs_info,
n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The classification head is new, and the spatial positional encoder is initialized from `chs_info`:
assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}
# In the post-local (b) and pre-local (c) architectures, the transformer and its
# positional encoder are discarded, so their weights are removed from the state dict:
filtered_model_state_dict = {
    k: v
    for k, v in model_state_dict.items()
    if not any(k.startswith(pre) for pre in ["transformer.", "pos_encoder."])
}
# b) Post-local downstream architecture
# ----------------------------------
model = SignalJEPA_PostLocal(
sfreq=sfreq,
input_window_seconds=2,
n_chans=len(chs_info), # detailed channel info is not needed for this model
n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == final_layer_keys
# c) Pre-local downstream architecture
# ------------------------------------
model = SignalJEPA_PreLocal(
sfreq=sfreq,
input_window_seconds=2,
n_chans=len(chs_info), # detailed channel info is not needed for this model
n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == {
"spatial_conv.1.weight",
"spatial_conv.1.bias",
"final_layer.1.weight",
"final_layer.1.bias",
}
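The downstream models are regular PyTorch modules, so the new classification head (and, optionally, the pre-trained encoder) can be fine-tuned with any optimizer and loss. The snippet below is an illustrative sketch only: the random tensors stand in for real windowed EEG data and labels, and the loss assumes a binary setup since `n_outputs=1`.
# Illustrative fine-tuning step with dummy data
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.BCEWithLogitsLoss()  # binary setup, n_outputs=1

x = torch.randn(8, len(chs_info), int(2 * sfreq))  # (batch, n_chans, n_times)
logits = model(x)
y = torch.randint(0, 2, logits.shape).float()  # dummy binary targets

optimizer.zero_grad()
loss = criterion(logits, y)
loss.backward()
optimizer.step()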