NikiPshg committed
Commit b2231f4 · verified · 1 Parent(s): 611e0d5

Upload folder using huggingface_hub

Files changed (3)
  1. checkpoint-32000/model.safetensors +3 -0
  2. model.py +128 -0
  3. test.py +51 -0
checkpoint-32000/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:84ffdaba11e18c729c299a64ff916eea5ed8b578307b1882f89d7a740e516f48
+ size 379203176
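
The commit message says the folder was uploaded with huggingface_hub. As a rough sketch only (not the exact command used for this commit; the repo id is a placeholder), an upload like this can be done with HfApi.upload_folder:

from huggingface_hub import HfApi

# Sketch only: "username/wavlm-endpointing" is a placeholder repo id, not the actual repository.
api = HfApi()
api.upload_folder(
    folder_path=".",  # local folder containing model.py, test.py and the checkpoint directory
    repo_id="username/wavlm-endpointing",
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)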
model.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoModel
+
+
+ class BCEWithLogitsLossLS(nn.Module):
+     """Binary cross-entropy with logits, with optional label smoothing."""
+
+     def __init__(self, label_smoothing=0.1, pos_weight=None, reduction='mean'):
+         super(BCEWithLogitsLossLS, self).__init__()
+         assert 0 <= label_smoothing < 1, "label_smoothing value must be between 0 and 1."
+         self.label_smoothing = label_smoothing
+         self.reduction = reduction
+         self.bce_with_logits = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction)
+
+     def forward(self, input, target):
+         if self.label_smoothing > 0:
+             # Pull hard 0/1 targets toward the smoothed values.
+             positive_smoothed_labels = 1.0 - self.label_smoothing
+             negative_smoothed_labels = self.label_smoothing
+             target = target * positive_smoothed_labels + \
+                 (1 - target) * negative_smoothed_labels
+
+         loss = self.bce_with_logits(input, target)
+         return loss
+
+
+ class WavLMForEndpointing(nn.Module):
+     """WavLM backbone with attention pooling and an MLP head for endpointing."""
+
+     def __init__(self, config, n_trainable_layers=6):
+         super().__init__()
+         self.wavlm = AutoModel.from_pretrained('microsoft/wavlm-base-plus', config=config)
+         self.config = config
+         self.n_trainable_layers = n_trainable_layers
+
+         # Freeze the backbone, then unfreeze only the last n_trainable_layers encoder layers.
+         for param in self.wavlm.parameters():
+             param.requires_grad = False
+
+         if self.n_trainable_layers > 0:
+             for i in range(self.n_trainable_layers):
+                 for param in self.wavlm.encoder.layers[-(i + 1)].parameters():
+                     param.requires_grad = True
+
+         self.pool_attention = nn.Sequential(
+             nn.Linear(config.hidden_size, 256),
+             nn.Tanh(),
+             nn.Linear(256, 1)
+         )
+
+         self.classifier = nn.Sequential(
+             nn.Linear(config.hidden_size, 256),
+             nn.LayerNorm(256),
+             nn.GELU(),
+             nn.Dropout(0.1),
+             nn.Linear(256, 64),
+             nn.LayerNorm(64),
+             nn.GELU(),
+             nn.Linear(64, 1)
+         )
+
+         # Initialise the linear layers of both heads.
+         for module in self.classifier:
+             if isinstance(module, nn.Linear):
+                 module.weight.data.normal_(mean=0.0, std=0.1)
+                 if module.bias is not None:
+                     module.bias.data.zero_()
+
+         for module in self.pool_attention:
+             if isinstance(module, nn.Linear):
+                 module.weight.data.normal_(mean=0.0, std=0.1)
+                 if module.bias is not None:
+                     module.bias.data.zero_()
+
+     def attention_pool(self, hidden_states, attention_mask):
+         # Per-frame attention scores of shape (batch, frames, 1).
+         attention_weights = self.pool_attention(hidden_states)
+
+         if attention_mask is None:
+             raise ValueError("attention_mask must be provided for attention pooling")
+
+         # Mask out padded frames before the softmax.
+         attention_weights = attention_weights + (
+             (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
+         )
+
+         attention_weights = F.softmax(attention_weights, dim=1)
+
+         # Apply attention to hidden states.
+         weighted_sum = torch.sum(hidden_states * attention_weights, dim=1)
+
+         return weighted_sum
+
+     def forward(self, input_values, attention_mask=None, labels=None):
+         outputs = self.wavlm(input_values, attention_mask=attention_mask)
+         hidden_states = outputs[0]
+
+         if attention_mask is not None:
+             # Downsample the sample-level mask to the frame rate of the encoder output.
+             input_length = attention_mask.size(1)
+             hidden_length = hidden_states.size(1)
+             ratio = input_length / hidden_length
+             indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
+             attention_mask = attention_mask[:, indices]
+             attention_mask = attention_mask.bool()
+
+         pooled = self.attention_pool(hidden_states, attention_mask)
+
+         logits = self.classifier(pooled)
+
+         if torch.isnan(logits).any():
+             raise ValueError("NaN values detected in logits")
+
+         if labels is not None:
+             # Re-balance positive/negative classes per batch, then apply label-smoothed BCE.
+             pos_weight = ((labels == 0).sum() / (labels == 1).sum()).clamp(min=0.1, max=10.0)
+             loss_fct = BCEWithLogitsLossLS(pos_weight=pos_weight)
+             labels = labels.float()
+             loss = loss_fct(logits.view(-1), labels.view(-1))
+
+             # L2 regularisation on the classifier head only.
+             l2_lambda = 0.01
+             l2_reg = torch.tensor(0., device=logits.device)
+             for param in self.classifier.parameters():
+                 l2_reg += torch.norm(param)
+             loss += l2_lambda * l2_reg
+
+             probs = torch.sigmoid(logits.detach())
+             return {"loss": loss, "logits": probs}
+
+         probs = torch.sigmoid(logits)
+         return {"logits": probs}
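
For context, a minimal usage sketch of WavLMForEndpointing on a padded batch (not part of this commit). It assumes the wavlm-base-plus feature extractor returns an attention_mask when padding=True, since attention_pool requires one; the waveforms and labels below are dummy placeholders.

import torch
import transformers
from model import WavLMForEndpointing

MODEL_NAME = 'microsoft/wavlm-base-plus'
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
processor = transformers.AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = WavLMForEndpointing(config)
model.eval()

# Two dummy 16 kHz waveforms of different lengths; padding=True also yields an attention_mask.
waveforms = [torch.randn(16000).numpy(), torch.randn(24000).numpy()]
inputs = processor(waveforms, sampling_rate=16000, return_tensors="pt", padding=True)

# Placeholder endpoint labels: passing labels makes forward() return a loss as well.
labels = torch.tensor([0.0, 1.0])
out = model(inputs["input_values"], attention_mask=inputs["attention_mask"], labels=labels)
print(out["loss"], out["logits"].shape)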
test.py ADDED
@@ -0,0 +1,51 @@
+ from model import WavLMForEndpointing
+ import torchaudio
+ import transformers
+ from safetensors import safe_open
+ import torch
+
+ MODEL_NAME = 'microsoft/wavlm-base-plus'
+
+ processor = transformers.AutoFeatureExtractor.from_pretrained(
+     MODEL_NAME
+ )
+
+ config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+ model = WavLMForEndpointing(config)
+
+ checkpoint_path = "/home/nikita/wavlm-endpointing-model/checkpoint-29000/model.safetensors"
+
+ # Load the fine-tuned weights from the safetensors checkpoint.
+ with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+     state_dict = {key: f.get_tensor(key) for key in f.keys()}
+
+ model.load_state_dict(state_dict)
+ print("Weights successfully loaded from safetensors")
+
+ model.eval()
+
+ while True:
+     # Prompt for the path to an audio file, then run a single prediction.
+     print('1234')
+     audio_path = str(input())
+     waveform, sample_rate = torchaudio.load(audio_path)
+
+     # Resample to the 16 kHz rate WavLM expects.
+     if sample_rate != 16000:
+         resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+         waveform = resampler(waveform)
+
+     # Downmix multi-channel audio to mono.
+     if waveform.shape[0] > 1:
+         waveform = waveform.mean(dim=0, keepdim=True)
+
+     inputs = processor(
+         waveform.squeeze().numpy(),
+         sampling_rate=16000,
+         return_tensors="pt",
+         padding=False,
+         truncation=False
+     )
+
+     with torch.no_grad():
+         result = model(**inputs)
+
+     print(result)
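
The checkpoint path in test.py is a hard-coded local path. As an alternative sketch, assuming the repo layout of this commit, the uploaded checkpoint-32000/model.safetensors could be fetched straight from the Hub with hf_hub_download (the repo id below is a placeholder):

from huggingface_hub import hf_hub_download

# Placeholder repo id; replace with the actual repository that holds this commit.
checkpoint_path = hf_hub_download(
    repo_id="username/wavlm-endpointing",
    filename="checkpoint-32000/model.safetensors",
)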