Spaces: Running on Zero

App-redesign (#1)

Commits:
- New app design, removing unnecessary code (5c7c0e25f77218ec5162453df6d15fcd71550f82)
- Merge branch 'gagan/redo-design' into pr/1 (1b7da9f7b82d5953b20899553e700fee83290a79)
- Small fix in launch (8ff0489f2c0201c4233e566b85acbd78ccfc632c)
Files changed:
- NatureLM/augmentations.py +0 -349
- NatureLM/checkpoint_utils.py +2 -27
- NatureLM/dataset.py +0 -550
- NatureLM/dist_utils.py +0 -109
- NatureLM/infer.py +83 -33
- NatureLM/logger.py +0 -190
- NatureLM/models/NatureLM.py +1 -1
- NatureLM/optims.py +0 -154
- NatureLM/processors.py +2 -2
- NatureLM/runner.py +0 -515
- NatureLM/storage_utils.py +0 -26
- NatureLM/task_metric_utils.py +0 -283
- NatureLM/task_metrics.py +0 -128
- NatureLM/utils.py +1 -26
- Space.yaml +1 -1
- app.py +252 -319
- requirements.txt +19 -31
NatureLM/augmentations.py (DELETED, 349 lines removed)

```python
import logging
import random

import numpy as np
import torch as th
from torch import nn
from torch.nn import functional as F

from NatureLM.utils import mel_frequencies

logger = logging.getLogger(__name__)


class RevEcho(nn.Module):
    """
    Hacky Reverb but runs on GPU without slowing down training. This reverb adds a
    succession of attenuated echos of the input signal to itself. Intuitively, the delay
    of the first echo will happen after roughly 2x the radius of the room and is
    controlled by `first_delay`. Then RevEcho keeps adding echos with the same delay and
    further attenuation until the amplitude ratio between the last and first echo is
    1e-3. The attenuation factor and the number of echos to adds is controlled by RT60
    (measured in seconds). RT60 is the average time to get to -60dB (n.b. volume is
    measured over the squared amplitude so this matches the 1e-3 ratio).

    At each call to RevEcho, `first_delay`, `initial` and `RT60` are sampled from their
    range. Then, to prevent this reverb from being too regular, the delay time is
    resampled uniformly within `first_delay +/- 10%`, as controlled by the `jitter`
    parameter.

    Finally, for a denser reverb, multiple trains of echos are added with different
    jitter noises.

    Args:
        - initial: amplitude of the first echo as a fraction of the input signal. For
          each sample, actually sampled from `[0, initial]`. Larger values means louder
          reverb. Physically, this would depend on the absorption of the room walls.
        - rt60: range of values to sample the RT60 in seconds, i.e. after RT60 seconds,
          the echo amplitude is 1e-3 of the first echo. The default values follow the
          recommendations of https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf,
          Section 2.4. Physically this would also be related to the absorption of the
          room walls and there is likely a relation between `RT60` and `initial`, which
          we ignore here.
        - first_delay: range of values to sample the first echo delay in seconds. The
          default values are equivalent to sampling a room of 3 to 10 meters.
        - repeat: how many train of echos with differents jitters to add. Higher values
          means a denser reverb.
        - jitter: jitter used to make each repetition of the reverb echo train slightly
          different. For instance a jitter of 0.1 means the delay between two echos will
          be in the range `first_delay +- 10%`, with the jittering noise being resampled
          after each single echo.
        - keep_clean: fraction of the reverb of the clean speech to add back to the
          ground truth. 0 = dereverberation, 1 = no dereverberation.
        - sample_rate: sample rate of the input signals.
    """

    def __init__(
        self,
        proba=0.5,
        initial=0.3,
        rt60=(0.3, 1.3),
        first_delay=(0.01, 0.03),
        repeat=3,
        jitter=0.1,
        keep_clean=0.1,
        sample_rate=16000,
        rng=None,
        seed=42,
    ):
        super().__init__()

        self.proba = proba
        self.initial = initial
        self.rt60 = rt60
        self.first_delay = first_delay
        self.repeat = repeat
        self.jitter = jitter
        self.keep_clean = keep_clean
        self.sample_rate = sample_rate
        self.seed = seed
        self.rng = rng if rng is not None else random.Random(self.seed)

    def _reverb(self, source, initial, first_delay, rt60):
        """
        Return the reverb for a single source.
        """
        length = source.shape[-1]
        reverb = th.zeros_like(source)

        for _ in range(self.repeat):
            frac = 1  # what fraction of the first echo amplitude is still here
            echo = initial * source
            while frac > 1e-3:
                # First jitter noise for the delay
                jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
                delay = min(1 + int(jitter * first_delay * self.sample_rate), length)

                # Delay the echo in time by padding with zero on the left
                echo = F.pad(echo[:, :, :-delay], (delay, 0))
                reverb += echo

                # Second jitter noise for the attenuation
                jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
                # we want, with `d` the attenuation, d**(rt60 / first_ms) = 1e-3
                # i.e. log10(d) = -3 * first_ms / rt60, so that
                attenuation = 10 ** (-3 * jitter * first_delay / rt60)
                echo *= attenuation
                frac *= attenuation

        return reverb

    def forward(self, samples):
        if self.rng.random() >= self.proba:
            return samples

        raw_wav = samples.get("raw_wav", None)

        # add channel dimension if not exist
        if raw_wav.dim() == 2:
            raw_wav = raw_wav.unsqueeze(1)

        # Sample characteristics for the reverb
        initial = self.rng.random() * self.initial
        first_delay = self.rng.uniform(*self.first_delay)
        rt60 = self.rng.uniform(*self.rt60)

        reverb_wav = self._reverb(raw_wav, initial, first_delay, rt60)
        raw_wav += self.keep_clean * reverb_wav

        # remove channel dimension
        if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
            raw_wav = raw_wav.squeeze(1)

        samples["raw_wav"] = raw_wav
        return samples


class BandMask(nn.Module):
    """
    Maskes bands of frequencies. Similar to Park, Daniel S., et al.
    "Specaugment: A simple data augmentation method for automatic speech recognition."
    (https://arxiv.org/pdf/1904.08779.pdf) but over the waveform.
    """

    def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000, rng=None, seed=42):
        """__init__.

        :param maxwidth: the maximum width to remove
        :param bands: number of bands
        :param sample_rate: signal sample rate
        """
        super().__init__()
        self.maxwidth = maxwidth
        self.bands = bands
        self.sample_rate = sample_rate
        self.seed = seed
        self.rng = rng if rng is not None else random.Random(self.seed)

    def forward(self, samples):
        raw_wav = samples.get("raw_wav", None)

        # add channel dimension if not exist
        if raw_wav.dim() == 2:
            raw_wav = raw_wav.unsqueeze(1)

        bands = self.bands
        bandwidth = int(abs(self.maxwidth) * bands)
        mels = mel_frequencies(bands, 40, self.sample_rate / 2) / self.sample_rate
        low = self.rng.randrange(bands)
        high = self.rng.randrange(low, min(bands, low + bandwidth))

        filters = LowPassFilters([mels[low], mels[high]]).to(raw_wav.device)

        low, midlow = filters(raw_wav)
        # band pass filtering
        out = raw_wav - midlow + low

        # remove channel dimension
        if out.dim() == 3 and out.shape[1] == 1:
            out = out.squeeze(1)

        samples["raw_wav"] = out
        return samples


class Shift(nn.Module):
    def __init__(self, shift=8192, same=False, rngth=None):
        """
        :param shift: randomly shifts the signals up to a given factor
        :param same: shifts both clean and noisy files by the same factor
        """
        super().__init__()
        self.shift = shift
        self.same = same
        self.rngth = rngth

    def forward(self, samples):
        raw_wav = samples.get("raw_wav", None)
        batch, channels, length = raw_wav.shape
        length = length - self.shift
        if self.shift > 0:
            offsets = th.randint(
                self.shift, [1 if self.same else batch, 1, 1], device=raw_wav.device, generator=self.rngth
            )
            offsets = offsets.expand(-1, channels, -1)
            indexes = th.arange(length, device=raw_wav.device)
            import pdb

            pdb.set_trace()
            raw_wav = raw_wav.gather(2, indexes + offsets)
        samples["raw_wav"] = raw_wav
        return samples


class TimeScale(nn.Module):
    """Fast time scale."""

    def __init__(self, scale=2.0, target=1, rngnp=None, seed=42):
        """
        :param scale: randomly scales up to this maximum factor
        """
        super().__init__()
        self.scale = scale
        self.target = target
        self.seed = seed
        self.rngnp = rngnp if rngnp is not None else np.random.default_rng(seed=self.seed)

    def forward(self, samples):
        try:
            raw_wav = samples.get("raw_wav")
        except KeyError:
            logger.error("Missing required key 'raw_wav' in samples dict")
            raise

        if "padding_mask" in samples:
            masks = samples.get("padding_mask")
        else:
            masks = th.ones_like(raw_wav)

        # add channel dimension if not exist
        if raw_wav.dim() == 2:
            raw_wav = raw_wav.unsqueeze(1)
            masks = masks.unsqueeze(1)

        # what to augment: noise, clean, or both
        if self.target == -1:
            targets = [i for i in range(raw_wav.shape[0])]
        else:
            targets = [self.target]

        for t in targets:
            signal = raw_wav[t]
            scaling = np.power(self.scale, self.rngnp.uniform(-1, 1))
            output_size = int(signal.shape[-1] * scaling)
            ref = th.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)

            ref1 = ref.clone().type(th.int64)
            ref2 = th.min(ref1 + 1, th.full_like(ref1, signal.shape[-1] - 1, dtype=th.int64))
            r = ref - ref1.type(ref.type())
            scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
            scaled_masks = masks[t][..., ref1] * (1 - r) + masks[t][..., ref2] * r

            # trim or zero pad to the original size
            if scaled_signal.shape[-1] > signal.shape[-1]:
                nframes_offset = (scaled_signal.shape[-1] - signal.shape[-1]) // 2
                scaled_signal = scaled_signal[..., nframes_offset : nframes_offset + signal.shape[-1]]
                scaled_masks = scaled_masks[..., nframes_offset : nframes_offset + signal.shape[-1]]
            else:
                nframes_diff = signal.shape[-1] - scaled_signal.shape[-1]
                pad_left = int(np.random.uniform() * nframes_diff)
                pad_right = nframes_diff - pad_left
                scaled_signal = F.pad(
                    input=scaled_signal, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
                )
                scaled_masks = F.pad(
                    input=scaled_masks, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
                )
            raw_wav[t] = scaled_signal
            masks[t] = scaled_masks

        # remove channel dimension
        if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
            raw_wav = raw_wav.squeeze(1)
            masks = masks.squeeze(1)

        samples["raw_wav"] = raw_wav
        samples["padding_mask"] = masks

        return samples


class Flip(nn.Module):
    def __init__(self, p=0.0, rngth=None):
        super(Flip, self).__init__()

        self.p = p
        self.rngth = rngth

    def forward(self, samples):
        raw_wav = samples["raw_wav"]
        if raw_wav.dim() > 2:
            flip_mask = th.rand(raw_wav.shape[0], device=raw_wav.device, generator=self.rngth) <= self.p
            raw_wav[flip_mask] = raw_wav[flip_mask].flip(-1)
        else:
            if th.rand(1, generator=self.rngth) <= self.p:
                raw_wav = raw_wav.flip(0)
        samples["raw_wav"] = raw_wav
        return samples


class LowPassFilters(th.nn.Module):
    """
    Bank of low pass filters.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 1] expressed as `f/f_s` where
            f_s is the samplerate.
        width (int | None): width of the filters (i.e. kernel_size=2 * width + 1).
            Default to `2 / min(cutoffs)`. Longer filters will have better attenuation
            but more side effects.
    Shape:
        - Input: `(*, T)`
        - Output: `(F, *, T)` with `F` the len of `cutoffs`.
    """

    def __init__(self, cutoffs: list, width: int | None = None):
        super().__init__()

        self.cutoffs = cutoffs

        if not width:
            width = int(2 / min(cutoffs))
        self.width = width

        window = th.hamming_window(2 * width + 1, periodic=False)
        t = np.arange(-width, width + 1, dtype=np.float32)
        filters = []
        for cutoff in cutoffs:
            sinc = th.from_numpy(np.sinc(2 * cutoff * t))
            filters.append(2 * cutoff * sinc * window)
        self.register_buffer("filters", th.stack(filters).unsqueeze(1))

    def forward(self, input):
        *others, t = input.shape
        input = input.view(-1, 1, t)
        out = F.conv1d(input, self.filters, padding=self.width)
        return out.permute(1, 0, 2).reshape(-1, *others, t)

    def __repr__(self):
        return "LossPassFilters(width={},cutoffs={})".format(self.width, self.cutoffs)
```
NatureLM/checkpoint_utils.py (CHANGED)

```diff
@@ -42,27 +42,6 @@ def get_state_dict(model, drop_untrained_params: bool = True) -> dict[str, Any]:
     return state_dict


-def torch_save_to_bucket(save_obj: Any, save_path: Union[str, os.PathLike], compress: bool = True) -> None:
-    """Save an object directly to GCS bucket without intermediate disk storage.
-
-    Args:
-        save_obj: Object to save (usually model state dict or checkpoint)
-        save_path: Path to save in GCS bucket (must be gs:// path)
-        compress: Whether to use compression. Default: True
-    """
-    if not is_gcs_path(save_path):
-        raise ValueError("save_path must be a GCS path")
-
-    # save to a temporary local file and then upload to GCS
-    with tempfile.NamedTemporaryFile() as tmp:
-        torch.save(save_obj, tmp.name, _use_new_zipfile_serialization=compress)
-        try:
-            save_path.upload_from(tmp.name)
-        except Exception as e:
-            logger.error(f"Error saving to GCP bucket: {e}")
-            raise e
-
-
 def save_model_checkpoint(
     model: nn.Module,
     save_path: Union[str, os.PathLike],
@@ -82,7 +61,7 @@ def save_model_checkpoint(
         extention (str): Extension to use for the checkpoint file. Default: "pth".
         **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
     """
-    if not …
+    if not os.path.exists(os.path.dirname(save_path)):
         raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")

     model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
@@ -93,8 +72,4 @@ def save_model_checkpoint(
     }

     logger.info("Saving checkpoint to {}.".format(save_path))
-
-    if is_gcs_path(save_path):
-        torch_save_to_bucket(save_obj, save_path)
-    else:
-        torch.save(save_obj, save_path)
+    torch.save(save_obj, save_path)
```
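The effect of this change is that checkpoints are now always written locally: the temporary-file upload to a GCS bucket is gone, and the save path must point at an existing directory. A minimal sketch of the resulting flow (`save_locally` is an illustrative name, not a function in the module):

```python
import os
import torch

def save_locally(save_obj: dict, save_path: str) -> None:
    # Fail fast if the target directory does not exist, then write with torch.save;
    # there is no longer a gs:// branch or torch_save_to_bucket round-trip.
    if not os.path.exists(os.path.dirname(save_path)):
        raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")
    torch.save(save_obj, save_path)
```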
NatureLM/dataset.py (DELETED, 550 lines removed)

```python
# Copyright (2024) Earth Species Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Mixing examples.
Can mix:
- base: options-detection add: open-ended:
    Take all open-ended labels. Add them to the options. Add them to the labels.
- base: open-ended, add: open-ended
    Concatenate labels
"""

import glob
import json
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import Literal

import numpy as np
import soundfile as sf
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from NatureLM.utils import snr_scale, time_scale


def write_example_to_file(base_filename, audio, sr=16000, suffix="_output", save_dir="debug_outputs"):
    """
    Writes the audio tensor to a file for debugging or inspection purposes.

    Args:
        base_filename (str): The base name of the original file.
        audio (torch.Tensor or numpy.ndarray): The audio waveform to save.
        sr (int): Sampling rate of the audio (default: 16000 Hz).
        suffix (str): Optional suffix to append to the filename.
        save_dir (str): Directory where the files will be saved.
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.numpy()  # Convert to numpy if necessary

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Create the output file path
    filename = f"{os.path.splitext(base_filename)[0]}{suffix}.wav"
    output_path = os.path.join(save_dir, filename)

    try:
        # Write the audio to the file
        sf.write(output_path, audio, sr)
        print(f"Saved audio to {output_path}")
    except Exception as e:
        print(f"Failed to write audio to file: {e}")


# Example usage in your code
# write_example_to_file(os.path.basename(ann["path"]), audio, suffix="_ts")


def collater(samples):
    """Collate samples into a batch.

    Samples is a list of dictionaries, each containing the following keys:
    - raw_wav: a list of tensors containing the raw audio waveform
    - text: a list of strings containing the text
    - task: a list of strings containing the task
    - id: a list of strings containing the id
    - prompt: a list of strings containing the prompt
    - index: a list of integers containing the index

    The indiviudal audio waveforms will be stacked along the batch dimension for easier
    processing in the audio model. To keep which audio belongs to which sample, we add
    the audio_chunk_sizes key to the batch dictionary.
    """
    flat_raw_wav = []
    audio_chunk_sizes = []

    for s in samples:
        chunk_size = len(s["raw_wav"])
        audio_chunk_sizes.append(chunk_size)
        flat_raw_wav.extend(s["raw_wav"])
    # raw_wav = [torch.from_numpy(a) for a in flat_raw_wav]
    raw_wav = flat_raw_wav
    raw_wav_length = torch.tensor([len(a) for a in raw_wav])
    raw_wav = pad_sequence(raw_wav, batch_first=True, padding_value=0)
    paddding_mask = torch.arange(raw_wav.size(1)).unsqueeze(0) >= raw_wav_length.unsqueeze(1)

    text = [s["text"] for s in samples]
    prompt = [s["prompt"] for s in samples]
    task = [s["task"] for s in samples]
    id = [s["id"] for s in samples]
    index = [s["index"] for s in samples]

    return {
        "raw_wav": raw_wav,
        "padding_mask": paddding_mask,
        "text": text,
        "task": task,
        "id": id,
        "prompt": prompt,
        "index": index,
        "audio_chunk_sizes": audio_chunk_sizes,
    }


class NatureLMDataset(Dataset):
    def __init__(
        self,
        ann_path: str | Path,
        *,
        max_length_seconds: int = 10,
        cropping: Literal["random", "start"] | None = "random",
        noise_prob: float = 0.0,
        noise_dirs: list[str] | list[Path] | None = None,
        low_snr: float = -5,
        high_snr: float = 20,
        time_scale_prob: float = 0.0,
        time_scale: float = 1.2,
        seed: int = 0,
        mixup_prob: float = 0.0,
        mixup_count: int = 3,
        use_augmentation: bool = False,
        mask_audio_prob: float = 0.0,
    ):
        super().__init__()

        ann_path = Path(ann_path)

        if not ann_path.exists():
            raise FileNotFoundError(f"Dataset file {ann_path} not found")

        try:
            with open(ann_path, "r") as f:
                data = json.load(f)
                self.annotation = data["annotation"]
        except (json.JSONDecodeError, KeyError):
            with open(ann_path, "r") as f:
                self.annotation = [json.loads(line) for line in f]

        #### mixup related variables
        ### hash table for tasks to sample the tasks faster
        self.tasks = defaultdict(list)
        for i, ann in enumerate(self.annotation):
            if "task" in ann and "text" in ann and ann["text"] != "None" and "path" in ann:
                self.tasks[ann["task"]].append(i)

        self.mixup_tasks = {
            task: []
            for task in self.tasks.keys()
            if task.endswith("simple-detection")
            or task.endswith("multiple-detection")  # Add more tasks after validating prompt mixing.
            or task.endswith("sci-detection-random")
            or task.endswith("common-detection-random")
        }
        for k in self.mixup_tasks.keys():
            # whichever the base, only mix open-ended tasks.
            if "sci-" in k:
                self.mixup_tasks[k] = [
                    task
                    for task in self.mixup_tasks.keys()
                    if task.endswith("sci-simple-detection") or task.endswith("sci-multiple-detection")
                ]
            elif "common-" in k:
                self.mixup_tasks[k] = [
                    task
                    for task in self.mixup_tasks.keys()
                    if task.endswith("common-simple-detection") or task.endswith("common-multiple-detection")
                ]
            else:
                self.mixup_tasks[k] = [task for task in self.mixup_tasks.keys() if "common-" in task]

        # print("num annotations", len(self.annotation))
        # print("annotation 0", self.annotation[0])
        # self.annotation = [a for a in self.annotation if "task" in a and "detection" not in a["task"]]  # no detection... :(
        self.max_length_seconds = max_length_seconds
        self.cropping = cropping
        self.use_augmentation = use_augmentation

        ### noise augmentation
        self.rng = random.Random(seed)
        self.rngnp = np.random.default_rng(seed=seed)
        self.noise_dirs = noise_dirs
        self.noise_prob = noise_prob
        self.noise_files = []
        self.low_snr = low_snr
        self.high_snr = high_snr
        self.mask_audio_prob = mask_audio_prob
        if noise_dirs is not None and len(self.noise_dirs) > 0 and self.use_augmentation:
            for noise_dir in noise_dirs:
                noise_from_dir = glob.glob(os.path.join(noise_dir, "*.wav"))
                if len(noise_from_dir) < 3000:
                    noise_from_dir = noise_from_dir * 3
                print("noise files from dir", noise_dir, len(noise_from_dir))
                self.noise_files.extend(noise_from_dir)

        ### mixup augmentation
        self.mixup_prob = mixup_prob
        self.mixup_count = mixup_count
        # ### time scale augmentation
        self.time_scale = time_scale
        self.time_scale_prob = time_scale_prob
        # tasks = set([annotation["task"] if "task" in annotation else "empty" for annotation in self.annotation])
        print(":::all tasks:::", self.tasks.keys())
        print("num examples", len(self.annotation))

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return collater(samples)

    def load_audio(self, audio_path, shift_allowed: bool, noise_allowed: bool):
        audio, sr = sf.read(audio_path)
        # assert sr == 16000
        if sr != 16000:
            print("other sr!", sr, audio_path)
        if len(audio.shape) == 2:  # stereo to mono
            audio = audio.mean(axis=1)

        ### time scale augmentation
        if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
            # print(f"{index} scaling audio")
            # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] )
            audio = time_scale(torch.tensor(audio), scale=self.time_scale, rngnp=self.rngnp).numpy()
            # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] , suffix='_ts')

        # Randomly crop a max_length_seconds window if audio is longer than 10 seconds
        if len(audio) > sr * self.max_length_seconds and self.cropping == "random":
            max_start = len(audio) - sr * self.max_length_seconds
            start = random.randint(0, max_start)
            audio = audio[start : start + sr * self.max_length_seconds]
        else:  # no random cropping
            audio = audio[: sr * self.max_length_seconds]  # Truncate audio to at most max_length_seconds

        ### noise augmentation
        audio = torch.tensor(audio)
        ### noise augmentation
        if (
            self.use_augmentation
            and self.rng.random() < self.noise_prob
            and len(self.noise_files) > 0
            and noise_allowed
        ):
            # write_example_to_file(os.path.basename(ann["path"]), audio)
            # print(f"{index} adding noise")
            noise_file = self.rng.choice(self.noise_files)
            if not os.path.exists(noise_file):
                print(f"Warning: noise file {noise_file} does not exist")
            else:
                noise_audio, noise_sr = sf.read(noise_file)
                assert noise_sr == 16000
                if len(noise_audio.shape) == 2:
                    noise_audio = noise_audio.mean(axis=1)

                noise_audio = torch.tensor(noise_audio)

                ### repeat or trim to the audio size
                if len(audio) > len(noise_audio):
                    if len(noise_audio) == 0:
                        print(
                            "----- Warning: Noise audio length is zero. ---------- ",
                            noise_file,
                        )
                        # Option 1: Skip noise augmentation by setting noise_audio to zero
                        noise_audio = torch.zeros_like(audio)
                    else:
                        nrepeats = int(np.maximum(2, np.ceil(len(audio) / len(noise_audio))))
                        noise_audio = noise_audio.repeat(nrepeats)
                ### Randomly crop the noise file if it is too long
                if len(noise_audio) > len(audio):
                    max_start = len(noise_audio) - len(audio)
                    start = random.randint(0, max_start)
                    noise_audio = noise_audio[start : start + len(audio)]

                ### remix with specified snr
                snr = self.rngnp.uniform(self.low_snr, self.high_snr)
                snr = torch.tensor([snr])
                noise_audio = snr_scale(audio, noise_audio, snr)
                audio = audio + noise_audio

                # write_example_to_file(os.path.basename(audio_path), audio, suffix='_noise')
                if len(audio) > self.max_length_seconds * sr:
                    print("long audio", len(audio), len(noise_audio))
                    audio = audio[: self.max_length_seconds * sr]

        # pad all audios to max_len_seconds in _getitem_ to ensure no padding inconsistencies.
        if len(audio) < sr * self.max_length_seconds:
            pad_size = sr * self.max_length_seconds - len(audio)
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        audio = torch.clamp(audio, -1.0, 1.0)

        return audio

    def _mix_labels(self, text, text_to_mix):
        """
        Given two comma-separated label strings (e.g., "gorilla, zebra"),
        combine them without introducing duplicates. If either is "None",
        return the other as-is (unless both are "None").
        """
        # If `text_to_mix` is explicitly "None", just return `text`.
        if text_to_mix == "None":
            return text

        # If `text` is explicitly "None", just return `text_to_mix`.
        if text == "None":
            return text_to_mix

        # Split both strings by comma, stripping whitespace
        text_list = [item.strip() for item in text.split(",") if item.strip()]
        text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]

        # Deduplicate: add only new items from text_to_mix_list
        combined_set = set(text_list)
        for item in text_to_mix_list:
            if item not in combined_set:
                text_list.append(item)
                combined_set.add(item)

        # If there's nothing left after deduplication, return "None".
        if not text_list:
            return "None"

        # Rejoin them into a comma-separated string
        return ", ".join(text_list)

    def _mix_prompts(self, text, text_to_mix, prompt):
        """
        If the prompt is in the form:
            "Which of these, if any, are present in the audio recording? option1, option2, ..."

        1. Parse out the question (before '?') and the list of prompt choices (after '?').
        2. Convert both `text` and `text_to_mix` into lists, checking for items not in the prompt.
        3. Append any missing answers to the prompt choices.
        4. Shuffle the choices.
        5. Reassemble and return the new prompt.

        If the prompt does not follow the expected structure, it is returned unmodified.
        """
        # Split into two parts: question + choices
        splitted = prompt.split("?")
        if len(splitted) != 2:
            # If we don't have exactly one question mark segment, just return the original prompt
            return prompt

        question = splitted[0].strip()
        potential_choices_str = splitted[1].strip()

        # Split the prompt choices
        if not potential_choices_str:
            prompt_choices = []
        else:
            prompt_choices = [c.strip() for c in potential_choices_str.split(",") if c.strip()]

        # Parse `text`
        text_list = [item.strip() for item in text.split(",") if item.strip()]

        # Parse `text_to_mix`
        text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]

        # Add any new items from text_list to the prompt
        for item in text_list:
            if item not in prompt_choices:
                prompt_choices.append(item)

        # Add any new items from text_to_mix_list to the prompt
        for item in text_to_mix_list:
            if item not in prompt_choices:
                prompt_choices.append(item)

        # Shuffle consistently with self.rng
        self.rng.shuffle(prompt_choices)

        # Reassemble
        new_prompt = question + "? " + ", ".join(prompt_choices)
        return new_prompt

    def _apply_mixup(self, prompt, audio, text, task, filename=None):
        # mixup_applied = False
        if (
            self.use_augmentation and self.rng.random() < self.mixup_prob and task in self.mixup_tasks
            # and text != "None"  # Allow complex 'None' examples.
        ):
            # write_example_to_file(os.path.basename(ann["path"]), audio)
            # print(f"{index} mixing up")
            mixup_indices = []
            for pair_task in self.mixup_tasks[task]:
                mixup_indices.extend(self.tasks[pair_task])
            # mixup_indices = mixup_indices.remove(index)

            if len(mixup_indices) == 0:
                print("No mixup partner found")
            else:
                ### choose n_mixup random partners
                n_mixup = self.rng.randint(1, self.mixup_count)
                mixup_indices = self.rng.sample(mixup_indices, n_mixup)
                # print(f"Mixing up with indices {mixup_indices}")
                for mixup_index in mixup_indices:
                    mixup_ann = self.annotation[mixup_index]
                    mixup_audio, _ = sf.read(mixup_ann["path"])
                    if len(mixup_audio.shape) == 2:
                        mixup_audio = mixup_audio.mean(axis=1)
                    mixup_audio = mixup_audio[: len(audio)]
                    if len(mixup_audio) < len(audio):
                        pad_size = len(audio) - len(mixup_audio)
                        mixup_audio = np.pad(mixup_audio, (0, pad_size), mode="constant")
                    mixup_audio = torch.from_numpy(mixup_audio).float()
                    lam = np.clip(self.rngnp.beta(1.0, 1.0), 0.1, 0.8)

                    # Mix the raw_wav
                    audio = lam * audio + (1 - lam) * mixup_audio

                    ### Mix the prompts if the labels are given in prompts
                    if text in prompt:
                        prompt = self._mix_prompts(text, mixup_ann["text"], prompt)

                    ### Mix the labels
                    text = self._mix_labels(text, mixup_ann["text"])

                    # mixup_applied = True

        # DEBUG: If mixup was actually applied, save the final audio
        # if mixup_applied and filename is not None:
        #     # Just add a suffix to the original filename to indicate mixup
        #     base_filename = os.path.basename(filename)
        #     write_example_to_file(
        #         base_filename=base_filename,
        #         audio=audio,
        #         sr=16000,
        #         suffix="_mixup",
        #         save_dir="mixup_outputs"
        #     )
        #     print(f"mixup for {filename}::: prompt {prompt} label {text}")

        return prompt, audio, text

    def _load_noise(self, shift_allowed: bool):
        noise_file = self.rng.choice(self.noise_files)
        noise_audio, noise_sr = sf.read(noise_file)
        assert noise_sr == 16000, f"Expected noise sample rate 16000, got {noise_sr}"
        if len(noise_audio.shape) == 2:
            noise_audio = noise_audio.mean(axis=1)

        # Time scale augmentation if applicable
        if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
            noise_audio = time_scale(torch.tensor(noise_audio), scale=self.time_scale, rngnp=self.rngnp).numpy()

        # Randomly crop or pad to match max_length_seconds
        if len(noise_audio) > self.max_length_seconds * 16000 and self.cropping == "random":
            max_start = len(noise_audio) - self.max_length_seconds * 16000
            start = random.randint(0, max_start)
            noise_audio = noise_audio[start : start + self.max_length_seconds * 16000]
        else:
            noise_audio = noise_audio[: self.max_length_seconds * 16000]

        # Pad if needed
        if len(noise_audio) < self.max_length_seconds * 16000:
            pad_size = self.max_length_seconds * 16000 - len(noise_audio)
            noise_audio = np.pad(noise_audio, (0, pad_size), mode="constant")

        noise_audio = torch.tensor(noise_audio).float()
        noise_audio = torch.clamp(noise_audio, -1.0, 1.0)
        return noise_audio

    def __getitem__(self, index):
        ann = self.annotation[index]
        # print("loading audio::", ann)
        shift_allowed = "pitch" not in ann.get("task", "")
        noise_allowed = (
            "/A/" not in ann.get("path", "")
            and "-qa" not in ann.get("task", "")
            and "icl" not in ann.get("task", "")
            and "caption" not in ann.get("task", "")
            and "animal-instructions" not in ann.get("task", "")
        )

        task = ann.get("task", "asr")
        text = ann["text"]
        prompt = ann["prompt"]

        replace_with_noise = (
            self.use_augmentation
            and task.endswith("detection")
            and self.rng.random() < self.mask_audio_prob
            and len(self.noise_files) > 0
        )

        if replace_with_noise:
            # Replace audio with noise
            audio = self._load_noise(shift_allowed)
            audios = [audio]
            text = "None"

        else:
            if "path" in ann and ann["path"] is not None:
                audio = self.load_audio(ann["path"], shift_allowed, noise_allowed)
                audios = [audio]
            else:
                audios = [self.load_audio(p, shift_allowed, noise_allowed) for p in ann["files"]]

            if len(audios) == 1:
                prompt, mixed_audio, text = self._apply_mixup(prompt, audio, text, task, filename=ann["path"])
                audios = [mixed_audio]

        return {
            "raw_wav": audios,
            "text": text,
            "task": task,
            "id": ann.get("path") or ";".join(ann["files"]),
            "prompt": prompt,
            "index": index,  # track which element for eval output
            "ann": ann,  # Include annotation for mixup
        }


if __name__ == "__main__":
    dataset = NatureLMDataset(
        ann_path="/home/ubuntu/foundation-model-storage/foundation-model-data/data/compiled-datasets/v1/s2_eval_valid.jsonl",
        noise_dirs=["resource/audio_demo"],
        max_length_seconds=10,
        use_augmentation=True,
        mixup_prob=1.0,  # For demonstration, force mixup if possible
        mixup_count=2,  # Up to 2 mixup partners
        mask_audio_prob=0.2,
        seed=42,
        noise_prob=0.5,
    )

    # Process just a few to see the saved mixups
    for i in range(300):
        sample = dataset[i]
        # print("Final text:", sample["text"])
        # print("Final prompt:", sample["prompt"])
        # print("-" * 40)
    print("Done! Look in 'debug_outputs' folder for saved mixup files.")
```
NatureLM/dist_utils.py (DELETED, 109 lines removed)

```python
"""
Adapted from salesforce@LAVIS. Below is the original copyright:
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.use_distributed = False
        return

    args.use_distributed = True

    torch.cuda.set_device(args.gpu)
    print(
        "| distributed init (rank {}, world {}): {}".format(args.rank, args.world_size, args.dist_url),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(days=365),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
```
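A small sketch of how these removed helpers were meant to be used: `init_distributed_mode` reads the rank and world-size environment variables, and `main_process` turns any function into a rank-0-only call. The `log_metrics` function below is illustrative only:

```python
@main_process
def log_metrics(step: int, loss: float) -> None:
    # Executes only on rank 0; on all other ranks the wrapper returns None
    # without calling the body, so logs are not duplicated across workers.
    print(f"step {step}: loss {loss:.4f}")
```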
NatureLM/infer.py
CHANGED
|
@@ -5,7 +5,7 @@ from pathlib import Path
|
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
-
import
|
| 9 |
import torch
|
| 10 |
|
| 11 |
from NatureLM.config import Config
|
|
@@ -16,10 +16,15 @@ from NatureLM.utils import move_to_device
|
|
| 16 |
_MAX_LENGTH_SECONDS = 10
|
| 17 |
_MIN_CHUNK_LENGTH_SECONDS = 0.5
|
| 18 |
_SAMPLE_RATE = 16000 # Assuming the model uses a sample rate of 16kHz
|
| 19 |
-
_AUDIO_FILE_EXTENSIONS = [
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def load_model_and_config(
|
|
@@ -32,7 +37,9 @@ def load_model_and_config(
|
|
| 32 |
model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
|
| 33 |
model = model.to(device).eval()
|
| 34 |
model.llama_tokenizer.pad_token_id = model.llama_tokenizer.eos_token_id
|
| 35 |
-
model.llama_model.generation_config.pad_token_id =
|
|
|
|
|
|
|
| 36 |
|
| 37 |
    cfg = Config.from_sources(cfg_path)
    return model, cfg


@@ -53,7 +60,7 @@ def sliding_window_inference(
    hop_length_seconds: float = 10.0,
    input_sr: int = _SAMPLE_RATE,
    device: str = _DEVICE,
-) -> str:
    """Run inference on a long audio file using sliding window approach.

    Args:

@@ -73,7 +80,7 @@ def sliding_window_inference(
        ValueError: If the audio file is too short or if the audio file path is invalid.
    """
    if isinstance(audio, str) or isinstance(audio, Path):
-        audio_array, input_sr =
    elif isinstance(audio, np.ndarray):
        audio_array = audio
        print(f"Using provided sample rate: {input_sr}")

@@ -86,13 +93,16 @@
    # Do initial check that the audio is long enough
    if audio_array.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
-        raise ValueError(

    start = 0
    stride = int(hop_length_seconds * input_sr)
    window_length = int(window_length_seconds * input_sr)

-    output =
    while True:
        chunk = audio_array[start : start + window_length]
        if chunk.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):

@@ -113,8 +123,16 @@
        prediction: str = model.generate(input_to_model, cfg.generate, prompt_list)[0]

        # Post-process the prediction
-        prediction = output_template(prediction, start / input_sr, (start + window_length) / input_sr)
-        output += prediction

        # Move the window
        start += stride

@@ -128,7 +146,9 @@
class Pipeline:
    """Pipeline for running NatureLM-audio inference on a list of audio files or audio arrays"""

-    def __init__(
        self.cfg_path = cfg_path

        # Load model and config

@@ -139,7 +159,9 @@
        # Download model from hub
        self.model, self.cfg = load_model_and_config(cfg_path)

-        self.processor = NatureLMAudioProcessor(

    def __call__(
        self,

@@ -149,6 +171,7 @@
        hop_length_seconds: float = 10.0,
        input_sample_rate: int = _SAMPLE_RATE,
        verbose: bool = False,
    ) -> list[str]:
        """Run inference on a list of audio file paths or a single audio file with a
        single query or a list of queries. If multiple queries are provided,

@@ -165,18 +188,11 @@
            Defaults to False.

        Returns:
-

        Raises:
            ValueError: If the number of audio files and queries do not match.
-
-        Example:
-            >>> pipeline = Pipeline()
-            >>> audios = ["assets/nri-GreenTreeFrogEvergladesNP.mp3"]
-            >>> queries = ["Which species is this? Provide the common name."]
-            >>> results = pipeline(audios, queries)
-            >>> print(results)
-            ['#0.00s - 10.00s#: Green Treefrog\n']
        """
        if isinstance(audios, str) or isinstance(audios, Path):
            audios = [audios]

@@ -189,7 +205,10 @@

        # Run inference
        results = []
-
            output = sliding_window_inference(
                audio,
                query,

@@ -209,21 +228,38 @@
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser("Run NatureLM-audio inference")
    parser.add_argument(
-        "-a",
    )
-    parser.add_argument("-q", "--query", type=str, required=True, help="Query for the model")
    parser.add_argument(
        "--cfg-path",
        type=str,
        default="configs/inference.yml",
        help="Path to the configuration file for the model",
    )
-    parser.add_argument("--output_path", type=str, default="inference_output.jsonl", help="Output path for the results")
    parser.add_argument(
-        "--
    )
    parser.add_argument(
-        "--hop_length_seconds",
    )
    args = parser.parse_args()


@@ -261,7 +297,9 @@ def main(
    audio_path = Path(audio_path)
    if audio_path.is_dir():
        audio_paths = []
-        print(
        for ext in _AUDIO_FILE_EXTENSIONS:
            audio_paths.extend(list(audio_path.rglob(f"*{ext}")))


@@ -278,18 +316,30 @@
    if not query:
        raise ValueError("Query cannot be empty")
    if not audio_paths:
-        raise ValueError(

    # Load model and config
    model, cfg = load_model_and_config(cfg_path)

    # Load audio processor
-    processor = NatureLMAudioProcessor(

    # Run inference
    results = {"audio_path": [], "output": []}
    for path in audio_paths:
-        output = sliding_window_inference(
        results["audio_path"].append(str(path))
        results["output"].append(output)
        print(f"Processed {path}, model output:\n=======\n{output}\n=======\n")
|
|
|
|
import numpy as np
import pandas as pd
+import librosa
import torch

from NatureLM.config import Config

_MAX_LENGTH_SECONDS = 10
_MIN_CHUNK_LENGTH_SECONDS = 0.5
_SAMPLE_RATE = 16000  # Assuming the model uses a sample rate of 16kHz
+_AUDIO_FILE_EXTENSIONS = [
+    ".wav",
+    ".mp3",
+    ".flac",
+    ".ogg",
+]  # Add other audio file formats as needed
+_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+__root_dir = Path(__file__).parent.parent
+_DEFAULT_CONFIG_PATH = __root_dir / "configs" / "inference.yml"


def load_model_and_config(

    model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
    model = model.to(device).eval()
    model.llama_tokenizer.pad_token_id = model.llama_tokenizer.eos_token_id
+    model.llama_model.generation_config.pad_token_id = (
+        model.llama_tokenizer.pad_token_id
+    )

    cfg = Config.from_sources(cfg_path)
    return model, cfg

    hop_length_seconds: float = 10.0,
    input_sr: int = _SAMPLE_RATE,
    device: str = _DEVICE,
+) -> list[dict[str, any]]:
    """Run inference on a long audio file using sliding window approach.

    Args:

        ValueError: If the audio file is too short or if the audio file path is invalid.
    """
    if isinstance(audio, str) or isinstance(audio, Path):
+        audio_array, input_sr = librosa.load(str(audio), sr=None, mono=False)
    elif isinstance(audio, np.ndarray):
        audio_array = audio
        print(f"Using provided sample rate: {input_sr}")

    # Do initial check that the audio is long enough
    if audio_array.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
+        raise ValueError(
+            f"Audio is too short. Minimum length is {_MIN_CHUNK_LENGTH_SECONDS} seconds."
+        )

    start = 0
    stride = int(hop_length_seconds * input_sr)
    window_length = int(window_length_seconds * input_sr)
+    window_id = 0

+    output = []  # Initialize output list
    while True:
        chunk = audio_array[start : start + window_length]
        if chunk.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):

        prediction: str = model.generate(input_to_model, cfg.generate, prompt_list)[0]

        # Post-process the prediction
+        # prediction = output_template(prediction, start / input_sr, (start + window_length) / input_sr)
+        # output += prediction
+        output.append(
+            {
+                "start_time": start / input_sr,
+                "end_time": (start + window_length) / input_sr,
+                "prediction": prediction,
+                "window_number": window_id,
+            }
+        )

        # Move the window
        start += stride

class Pipeline:
    """Pipeline for running NatureLM-audio inference on a list of audio files or audio arrays"""

+    def __init__(
+        self, model: NatureLM = None, cfg_path: str | Path = _DEFAULT_CONFIG_PATH
+    ):
        self.cfg_path = cfg_path

        # Load model and config

        # Download model from hub
        self.model, self.cfg = load_model_and_config(cfg_path)

+        self.processor = NatureLMAudioProcessor(
+            sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS
+        )

    def __call__(
        self,

        hop_length_seconds: float = 10.0,
        input_sample_rate: int = _SAMPLE_RATE,
        verbose: bool = False,
+        progress_bar=None,
    ) -> list[str]:
        """Run inference on a list of audio file paths or a single audio file with a
        single query or a list of queries. If multiple queries are provided,

            Defaults to False.

        Returns:
+            list[list[dict]]: List of model outputs for each audio file. Each output is a list of dictionaries
+                containing the start time, end time, and prediction for each chunk of audio.

        Raises:
            ValueError: If the number of audio files and queries do not match.
        """
        if isinstance(audios, str) or isinstance(audios, Path):
            audios = [audios]

        # Run inference
        results = []
+        progress_bar(0, desc="Starting")
+        for audio, query in progress_bar.tqdm(
+            zip(audios, queries), desc="Generating responses", total=len(audios)
+        ):
            output = sliding_window_inference(
                audio,
                query,

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser("Run NatureLM-audio inference")
    parser.add_argument(
+        "-a",
+        "--audio",
+        type=str,
+        required=True,
+        help="Path to an audio file or a directory containing audio files",
+    )
+    parser.add_argument(
+        "-q", "--query", type=str, required=True, help="Query for the model"
    )
    parser.add_argument(
        "--cfg-path",
        type=str,
        default="configs/inference.yml",
        help="Path to the configuration file for the model",
    )
    parser.add_argument(
+        "--output_path",
+        type=str,
+        default="inference_output.jsonl",
+        help="Output path for the results",
+    )
+    parser.add_argument(
+        "--window_length_seconds",
+        type=float,
+        default=10.0,
+        help="Length of the sliding window in seconds",
    )
    parser.add_argument(
+        "--hop_length_seconds",
+        type=float,
+        default=10.0,
+        help="Hop length for the sliding window in seconds",
    )
    args = parser.parse_args()

    audio_path = Path(audio_path)
    if audio_path.is_dir():
        audio_paths = []
+        print(
+            f"Searching for audio files in {str(audio_path)} with extensions {', '.join(_AUDIO_FILE_EXTENSIONS)}"
+        )
        for ext in _AUDIO_FILE_EXTENSIONS:
            audio_paths.extend(list(audio_path.rglob(f"*{ext}")))

    if not query:
        raise ValueError("Query cannot be empty")
    if not audio_paths:
+        raise ValueError(
+            "No audio files found. Please check the path or file extensions."
+        )

    # Load model and config
    model, cfg = load_model_and_config(cfg_path)

    # Load audio processor
+    processor = NatureLMAudioProcessor(
+        sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS
+    )

    # Run inference
    results = {"audio_path": [], "output": []}
    for path in audio_paths:
+        output = sliding_window_inference(
+            path,
+            query,
+            processor,
+            model,
+            cfg,
+            window_length_seconds,
+            hop_length_seconds,
+        )
        results["audio_path"].append(str(path))
        results["output"].append(output)
        print(f"Processed {path}, model output:\n=======\n{output}\n=======\n")
|
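Since the module already imports pandas, one natural way to consume the new per-window output is to flatten it into a table. Again a hedged sketch rather than part of the PR; it assumes the windows list from the previous example and a string prediction in every window.

import pandas as pd

# Flatten the per-window dicts returned by sliding_window_inference into a DataFrame.
df = pd.DataFrame(windows).sort_values("window_number").reset_index(drop=True)

# Drop windows whose prediction is empty after stripping whitespace.
df = df[df["prediction"].str.strip().astype(bool)]
print(df[["start_time", "end_time", "prediction"]].to_string(index=False))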
NatureLM/logger.py
DELETED
@@ -1,190 +0,0 @@
-import datetime
-import logging
-import time
-from collections import defaultdict, deque
-
-import torch
-import torch.distributed as dist
-import wandb
-
-from NatureLM.dist_utils import is_dist_avail_and_initialized, is_main_process
-
-
-class SmoothedValue(object):
-    """Track a series of values and provide access to smoothed values over a
-    window or the global series average.
-    """
-
-    def __init__(self, window_size=20, fmt=None):
-        if fmt is None:
-            fmt = "{median:.4f} ({global_avg:.4f})"
-        self.deque = deque(maxlen=window_size)
-        self.total = 0.0
-        self.count = 0
-        self.fmt = fmt
-
-    def update(self, value, n=1):
-        self.deque.append(value)
-        self.count += n
-        self.total += value * n
-
-    def synchronize_between_processes(self):
-        """
-        Warning: does not synchronize the deque!
-        """
-        if not is_dist_avail_and_initialized():
-            return
-        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
-        dist.barrier()
-        dist.all_reduce(t)
-        t = t.tolist()
-        self.count = int(t[0])
-        self.total = t[1]
-
-    @property
-    def median(self):
-        d = torch.tensor(list(self.deque))
-        return d.median().item()
-
-    @property
-    def avg(self):
-        d = torch.tensor(list(self.deque), dtype=torch.float32)
-        return d.mean().item()
-
-    @property
-    def global_avg(self):
-        return self.total / self.count
-
-    @property
-    def max(self):
-        return max(self.deque)
-
-    @property
-    def value(self):
-        return self.deque[-1]
-
-    def __str__(self):
-        return self.fmt.format(
-            median=self.median,
-            avg=self.avg,
-            global_avg=self.global_avg,
-            max=self.max,
-            value=self.value,
-        )
-
-
-class MetricLogger(object):
-    def __init__(self, delimiter="\t"):
-        self.meters = defaultdict(SmoothedValue)
-        self.delimiter = delimiter
-
-    def update(self, **kwargs):
-        for k, v in kwargs.items():
-            if isinstance(v, torch.Tensor):
-                v = v.item()
-            assert isinstance(v, (float, int))
-            self.meters[k].update(v)
-
-    def __getattr__(self, attr):
-        if attr in self.meters:
-            return self.meters[attr]
-        if attr in self.__dict__:
-            return self.__dict__[attr]
-        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
-
-    def __str__(self):
-        loss_str = []
-        for name, meter in self.meters.items():
-            loss_str.append("{}: {}".format(name, str(meter)))
-        return self.delimiter.join(loss_str)
-
-    def global_avg(self):
-        loss_str = []
-        for name, meter in self.meters.items():
-            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
-        return self.delimiter.join(loss_str)
-
-    def synchronize_between_processes(self):
-        for meter in self.meters.values():
-            meter.synchronize_between_processes()
-
-    def add_meter(self, name, meter):
-        self.meters[name] = meter
-
-    def log_every(self, iterable, print_freq, header=None, logger=None, start_step=None):
-        i = 0
-        if not header:
-            header = ""
-        start_time = time.time()
-        end = time.time()
-        iter_time = SmoothedValue(fmt="{avg:.4f}")
-        data_time = SmoothedValue(fmt="{avg:.4f}")
-        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
-        log_msg = [
-            header,
-            "[{0" + space_fmt + "}/{1}]",
-            "eta: {eta}",
-            "{meters}",
-            "time: {time}",
-            "data: {data}",
-        ]
-        if torch.cuda.is_available():
-            log_msg.append("max mem: {memory:.0f}")
-        log_msg = self.delimiter.join(log_msg)
-        MB = 1024.0 * 1024.0
-        for obj in iterable:
-            data_time.update(time.time() - end)
-            yield obj
-            iter_time.update(time.time() - end)
-            if i % print_freq == 0 or i == len(iterable) - 1:
-                if is_main_process():
-                    if logger is not None:
-                        assert start_step is not None, "start_step is needed to compute global_step!"
-                        for name, meter in self.meters.items():
-                            logger.add_scalar("{}".format(name), float(str(meter)), global_step=start_step + i)
-                        # Log to wandb
-                        wandb.log({name: float(str(meter)) for name, meter in self.meters.items()}, step=start_step + i)
-                eta_seconds = iter_time.global_avg * (len(iterable) - i)
-                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
-                if torch.cuda.is_available():
-                    print(
-                        log_msg.format(
-                            i,
-                            len(iterable),
-                            eta=eta_string,
-                            meters=str(self),
-                            time=str(iter_time),
-                            data=str(data_time),
-                            memory=torch.cuda.max_memory_allocated() / MB,
-                        )
-                    )
-                else:
-                    print(
-                        log_msg.format(
-                            i,
-                            len(iterable),
-                            eta=eta_string,
-                            meters=str(self),
-                            time=str(iter_time),
-                            data=str(data_time),
-                        )
-                    )
-            i += 1
-            end = time.time()
-        total_time = time.time() - start_time
-        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-        print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
-
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-def setup_logger():
-    logging.basicConfig(
-        level=logging.INFO if is_main_process() else logging.WARN,
-        format="%(asctime)s [%(levelname)s] %(message)s",
-        handlers=[logging.StreamHandler()],
-    )
|
|
|
|
NatureLM/models/NatureLM.py
CHANGED

@@ -645,7 +645,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        with torch.autocast(self.device.type, dtype=torch.bfloat16):
-            outputs = self.llama_model.generate(
+            outputs = self.llama_model.generate(
                inputs_embeds=embeds.bfloat16(),
                max_new_tokens=generate_cfg.max_new_tokens,
                stopping_criteria=stopping_criteria,
|
NatureLM/optims.py
DELETED
@@ -1,154 +0,0 @@
-# This script is from https://github.com/salesforce/LAVIS/blob/main/lavis/common/optims.py
-
-import logging
-import math
-
-import torch
-
-from NatureLM.config import OptimizerConfig
-
-
-class LinearWarmupStepLRScheduler:
-    def __init__(
-        self,
-        optimizer,
-        max_epoch,
-        min_lr,
-        init_lr,
-        decay_rate=1,
-        warmup_start_lr=-1,
-        warmup_steps=0,
-        **kwargs,
-    ):
-        self.optimizer = optimizer
-
-        self.max_epoch = max_epoch
-        self.min_lr = min_lr
-
-        self.decay_rate = decay_rate
-
-        self.init_lr = init_lr
-        self.warmup_steps = warmup_steps
-        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
-
-    def step(self, cur_epoch, cur_step):
-        if cur_epoch == 0:
-            warmup_lr_schedule(
-                step=cur_step,
-                optimizer=self.optimizer,
-                max_step=self.warmup_steps,
-                init_lr=self.warmup_start_lr,
-                max_lr=self.init_lr,
-            )
-        else:
-            step_lr_schedule(
-                epoch=cur_epoch,
-                optimizer=self.optimizer,
-                init_lr=self.init_lr,
-                min_lr=self.min_lr,
-                decay_rate=self.decay_rate,
-            )
-
-
-class LinearWarmupCosineLRScheduler:
-    def __init__(
-        self,
-        optimizer,
-        max_epoch,
-        iters_per_epoch,
-        min_lr,
-        init_lr,
-        warmup_steps=0,
-        warmup_start_lr=-1,
-        **kwargs,
-    ):
-        self.optimizer = optimizer
-
-        self.max_epoch = max_epoch
-        self.iters_per_epoch = iters_per_epoch
-        self.min_lr = min_lr
-
-        self.init_lr = init_lr
-        self.warmup_steps = warmup_steps
-        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
-
-    def step(self, cur_epoch, cur_step):
-        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
-        if total_cur_step < self.warmup_steps:
-            warmup_lr_schedule(
-                step=cur_step,
-                optimizer=self.optimizer,
-                max_step=self.warmup_steps,
-                init_lr=self.warmup_start_lr,
-                max_lr=self.init_lr,
-            )
-        else:
-            cosine_lr_schedule(
-                epoch=total_cur_step,
-                optimizer=self.optimizer,
-                max_epoch=self.max_epoch * self.iters_per_epoch,
-                init_lr=self.init_lr,
-                min_lr=self.min_lr,
-            )
-
-
-def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
-    """Decay the learning rate"""
-    lr = (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = lr
-
-
-def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
-    """Warmup the learning rate"""
-    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = lr
-
-
-def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
-    """Decay the learning rate"""
-    lr = max(min_lr, init_lr * (decay_rate**epoch))
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = lr
-
-
-def get_optimizer(model, config: OptimizerConfig):
-    num_parameters = 0
-    p_wd, p_non_wd = [], []
-    for n, p in model.named_parameters():
-        if not p.requires_grad:
-            continue  # frozen weights
-        print(n)
-        if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
-            p_non_wd.append(p)
-        else:
-            p_wd.append(p)
-        num_parameters += p.data.nelement()
-    logging.info("number of trainable parameters: %d" % num_parameters)
-    optim_params = [
-        {
-            "params": p_wd,
-            "weight_decay": float(config.weight_decay),
-        },
-        {"params": p_non_wd, "weight_decay": 0},
-    ]
-    beta2 = config.beta2
-    if config.device == "cpu":
-        optimizer = torch.optim.AdamW(
-            optim_params,
-            lr=float(config.init_lr),
-            weight_decay=float(config.weight_decay),
-            betas=(0.9, beta2),
-        )
-    else:
-        import bitsandbytes as bnb
-
-        optimizer = bnb.optim.PagedAdamW8bit(
-            optim_params,
-            lr=float(config.init_lr),
-            weight_decay=float(config.weight_decay),
-            betas=(0.9, beta2),
-        )
-
-    return optimizer
|
|
|
|
|
|
NatureLM/processors.py
CHANGED

@@ -6,7 +6,7 @@ from dataclasses import dataclass, field

import numpy as np
import resampy
-import
+import librosa
import torch


@@ -49,7 +49,7 @@ class NatureLMAudioProcessor:
    def prepare_audio(self, audio: list[float] | np.ndarray | os.PathLike, input_sr: int = None) -> torch.Tensor:
        """Prepare an audio array or file path for inference"""
        if isinstance(audio, str | os.PathLike):
-            audio, sr =
+            audio, sr = librosa.load(audio, sr=None, mono=False)
            input_sr = sr
        elif isinstance(audio, list):
            audio = np.array(audio)
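Because librosa.load(..., sr=None, mono=False) keeps the file's native sample rate and channel layout, callers of prepare_audio may now receive multichannel audio at an arbitrary rate. The sketch below only illustrates what that implies (downmixing and resampling with resampy, which this module already imports); the path is a placeholder, and the 16 kHz target mirrors the _SAMPLE_RATE constant used elsewhere in this PR.

import librosa
import numpy as np
import resampy

# librosa.load with sr=None returns the native sample rate; mono=False returns
# shape (n,) for mono files and (channels, n) for multichannel files.
audio, sr = librosa.load("example_recordings/dawn_chorus.wav", sr=None, mono=False)

if audio.ndim > 1:
    audio = np.mean(audio, axis=0)  # downmix to mono

if sr != 16000:
    audio = resampy.resample(audio, sr, 16000)  # resample to the model's 16 kHz
    sr = 16000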
|
NatureLM/runner.py
DELETED
|
@@ -1,515 +0,0 @@
|
|
| 1 |
-
# This script is based on https://github.com/salesforce/LAVIS/blob/main/lavis/runners/runner_base.py
|
| 2 |
-
|
| 3 |
-
import datetime
|
| 4 |
-
import json
|
| 5 |
-
import logging
|
| 6 |
-
import os
|
| 7 |
-
import time
|
| 8 |
-
from collections import defaultdict
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
import torch.distributed
|
| 13 |
-
import torch.distributed as dist
|
| 14 |
-
import wandb
|
| 15 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 16 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 17 |
-
|
| 18 |
-
from NatureLM.config import Config
|
| 19 |
-
from NatureLM.dist_utils import get_rank, get_world_size, is_dist_avail_and_initialized, is_main_process, main_process
|
| 20 |
-
from NatureLM.logger import MetricLogger, SmoothedValue
|
| 21 |
-
from NatureLM.optims import LinearWarmupCosineLRScheduler, get_optimizer
|
| 22 |
-
from NatureLM.task_metrics import get_task_metrics
|
| 23 |
-
from NatureLM.utils import get_dataloader, prepare_sample_dist
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class Runner:
|
| 27 |
-
def __init__(self, cfg: Config, model, datasets, job_id):
|
| 28 |
-
self.config = cfg
|
| 29 |
-
|
| 30 |
-
# log
|
| 31 |
-
device = "cuda:0"
|
| 32 |
-
if is_main_process():
|
| 33 |
-
if self.config.run.wandb_enabled:
|
| 34 |
-
wandb.init(project="earthlm", config=self.config.model_dump())
|
| 35 |
-
else:
|
| 36 |
-
wandb.init(mode="disabled")
|
| 37 |
-
|
| 38 |
-
if "LOCAL_RANK" in os.environ:
|
| 39 |
-
device = int(os.environ["LOCAL_RANK"])
|
| 40 |
-
else:
|
| 41 |
-
device = self.config.run.device
|
| 42 |
-
print(f"device is {device} could have been {self.config.run.device}")
|
| 43 |
-
self.output_dir = Path(self.config.run.output_dir) / job_id
|
| 44 |
-
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 45 |
-
self.log_writter = SummaryWriter(self.output_dir)
|
| 46 |
-
|
| 47 |
-
# settings
|
| 48 |
-
self.device = torch.device(device)
|
| 49 |
-
self.use_distributed = self.config.run.use_distributed
|
| 50 |
-
self.start_epoch = 0
|
| 51 |
-
self.max_epoch = self.config.run.optims.max_epoch
|
| 52 |
-
self.evaluate_only = self.config.run.evaluate
|
| 53 |
-
self.cuda_enabled = self.device.type == "cuda"
|
| 54 |
-
|
| 55 |
-
# test prompt
|
| 56 |
-
self.prompt_template = self.config.model.prompt_template
|
| 57 |
-
|
| 58 |
-
# model
|
| 59 |
-
self._model = model
|
| 60 |
-
torch.nn.SyncBatchNorm.convert_sync_batchnorm(self._model)
|
| 61 |
-
self._model.to(self.device)
|
| 62 |
-
if self.use_distributed:
|
| 63 |
-
self.model = DDP(
|
| 64 |
-
self._model,
|
| 65 |
-
find_unused_parameters=True,
|
| 66 |
-
static_graph=False,
|
| 67 |
-
device_ids=[self.device],
|
| 68 |
-
)
|
| 69 |
-
else:
|
| 70 |
-
self.model = self._model
|
| 71 |
-
|
| 72 |
-
# dataloaders
|
| 73 |
-
self.train_loader = get_dataloader(
|
| 74 |
-
datasets["train"],
|
| 75 |
-
self.config.run,
|
| 76 |
-
is_train=True,
|
| 77 |
-
use_distributed=self.use_distributed,
|
| 78 |
-
)
|
| 79 |
-
self.valid_loader = get_dataloader(
|
| 80 |
-
datasets["valid"],
|
| 81 |
-
self.config.run,
|
| 82 |
-
is_train=False,
|
| 83 |
-
use_distributed=self.use_distributed,
|
| 84 |
-
)
|
| 85 |
-
self.test_loader = get_dataloader(
|
| 86 |
-
datasets["test"],
|
| 87 |
-
self.config.run,
|
| 88 |
-
is_train=False,
|
| 89 |
-
use_distributed=self.use_distributed,
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
# scaler
|
| 93 |
-
self.use_amp = self.config.run.amp
|
| 94 |
-
if self.use_amp:
|
| 95 |
-
self.scaler = torch.cuda.amp.GradScaler()
|
| 96 |
-
else:
|
| 97 |
-
self.scaler = None
|
| 98 |
-
|
| 99 |
-
# optimizer & scheduler
|
| 100 |
-
self.iters_per_epoch = (
|
| 101 |
-
len(self.train_loader) if self.config.run.epoch_based else self.config.run.iters_per_epoch
|
| 102 |
-
)
|
| 103 |
-
self.optimizer = get_optimizer(self.model, self.config.run.optims)
|
| 104 |
-
self.scheduler = LinearWarmupCosineLRScheduler(
|
| 105 |
-
self.optimizer,
|
| 106 |
-
max_epoch=self.max_epoch,
|
| 107 |
-
iters_per_epoch=self.iters_per_epoch,
|
| 108 |
-
min_lr=self.config.run.optims.min_lr,
|
| 109 |
-
init_lr=self.config.run.optims.init_lr,
|
| 110 |
-
warmup_steps=self.config.run.optims.warmup_steps,
|
| 111 |
-
warmup_start_lr=self.config.run.optims.warmup_start_lr,
|
| 112 |
-
)
|
| 113 |
-
|
| 114 |
-
#### augmentations
|
| 115 |
-
# self.rng = random.Random(self.config.run.seed)
|
| 116 |
-
# self.rngnp = np.random.default_rng(seed=self.config.run.seed)
|
| 117 |
-
# self.rngth = torch.Generator(device=args.device)
|
| 118 |
-
# self.rngth.manual_seed(self.config.run.seed)
|
| 119 |
-
# augments = []
|
| 120 |
-
# if self.config.run.augmentations.flip:
|
| 121 |
-
# augments.append(augmentations.Flip(self.config.run.augmentations.flip, rngth=self.rngth, seed=self.config.run.seed))
|
| 122 |
-
# if self.config.run.augmentations.bandmask:
|
| 123 |
-
# augments.append(augmentations.BandMask(self.config.run.augmentations.bandmask, sample_rate=args.sample_rate, rng=self.rng, seed=self.config.run.seed))
|
| 124 |
-
# if self.config.run.augmentations.revecho:
|
| 125 |
-
# augments.append(
|
| 126 |
-
# augmentations.RevEcho(proba=self.config.run.augmentations.revecho,rng=self.rng,seed=self.config.run.seed))
|
| 127 |
-
# self.augment = torch.nn.Sequential(*augments)
|
| 128 |
-
|
| 129 |
-
self.log_config()
|
| 130 |
-
|
| 131 |
-
def unwrap_dist_model(self, model):
|
| 132 |
-
if self.use_distributed:
|
| 133 |
-
return model.module
|
| 134 |
-
else:
|
| 135 |
-
return model
|
| 136 |
-
|
| 137 |
-
def train_epoch(self, epoch):
|
| 138 |
-
self.model.train()
|
| 139 |
-
|
| 140 |
-
metric_logger = MetricLogger(delimiter=" ")
|
| 141 |
-
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
| 142 |
-
metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
|
| 143 |
-
|
| 144 |
-
logging.info("Start training epoch {}, {} iters per inner epoch.".format(epoch, self.iters_per_epoch))
|
| 145 |
-
header = "Train: data epoch: [{}]".format(epoch)
|
| 146 |
-
|
| 147 |
-
# Get gradient clipping parameters from config
|
| 148 |
-
clip_grad_norm = self.config.run.optims.max_grad_norm
|
| 149 |
-
clip_grad_value = self.config.run.optims.max_grad_value
|
| 150 |
-
|
| 151 |
-
for i in metric_logger.log_every(
|
| 152 |
-
range(self.iters_per_epoch),
|
| 153 |
-
self.config.run.log_freq,
|
| 154 |
-
header=header,
|
| 155 |
-
logger=self.log_writter,
|
| 156 |
-
start_step=epoch * self.iters_per_epoch,
|
| 157 |
-
):
|
| 158 |
-
if i >= self.iters_per_epoch:
|
| 159 |
-
break
|
| 160 |
-
|
| 161 |
-
samples = next(self.train_loader)
|
| 162 |
-
|
| 163 |
-
samples = prepare_sample_dist(samples, self.device)
|
| 164 |
-
|
| 165 |
-
#### augmentation
|
| 166 |
-
# if False:
|
| 167 |
-
# samples = self.augment(samples)
|
| 168 |
-
|
| 169 |
-
self.scheduler.step(cur_epoch=epoch, cur_step=i)
|
| 170 |
-
|
| 171 |
-
with torch.autocast(self.device.type, enabled=self.use_amp, dtype=torch.bfloat16):
|
| 172 |
-
loss = self.model(samples)["loss"]
|
| 173 |
-
if torch.isnan(loss):
|
| 174 |
-
print("loss nan", samples)
|
| 175 |
-
# continue
|
| 176 |
-
|
| 177 |
-
if self.use_amp and self.scaler:
|
| 178 |
-
self.scaler.scale(loss).backward()
|
| 179 |
-
else:
|
| 180 |
-
loss.backward()
|
| 181 |
-
|
| 182 |
-
# Apply gradient clipping
|
| 183 |
-
if clip_grad_norm is not None:
|
| 184 |
-
if self.use_amp and self.scaler:
|
| 185 |
-
self.scaler.unscale_(self.optimizer)
|
| 186 |
-
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad_norm)
|
| 187 |
-
if clip_grad_value is not None:
|
| 188 |
-
if self.use_amp and self.scaler:
|
| 189 |
-
self.scaler.unscale_(self.optimizer)
|
| 190 |
-
torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=clip_grad_value)
|
| 191 |
-
|
| 192 |
-
if (i + 1) % self.config.run.accum_grad_iters == 0:
|
| 193 |
-
if self.use_amp and self.scaler:
|
| 194 |
-
self.scaler.step(self.optimizer)
|
| 195 |
-
self.scaler.update()
|
| 196 |
-
else:
|
| 197 |
-
self.optimizer.step()
|
| 198 |
-
self.optimizer.zero_grad()
|
| 199 |
-
|
| 200 |
-
metric_logger.update(loss=loss.item())
|
| 201 |
-
metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
|
| 202 |
-
|
| 203 |
-
metric_logger.synchronize_between_processes()
|
| 204 |
-
logging.info("Averaged stats: " + str(metric_logger.global_avg()))
|
| 205 |
-
return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
| 206 |
-
|
| 207 |
-
@torch.no_grad()
|
| 208 |
-
def valid_epoch(self, epoch, split, decode=True, save_json=False, decode_ratio=1.0):
|
| 209 |
-
"""
|
| 210 |
-
Decode = True will lead to calculation of custom metrics which are based on text.
|
| 211 |
-
decode_ratio controls the percentage of batches which will have custom metrics computed,
|
| 212 |
-
a speed trade-off due to the cost of the 'generate' method.
|
| 213 |
-
"""
|
| 214 |
-
model = self.unwrap_dist_model(self.model)
|
| 215 |
-
model.eval()
|
| 216 |
-
|
| 217 |
-
dataloader = getattr(self, split + "_loader", None)
|
| 218 |
-
assert dataloader is not None, f"{split}_loader does not exist."
|
| 219 |
-
|
| 220 |
-
metric_logger = MetricLogger(delimiter=" ")
|
| 221 |
-
header = f"Eval: data epoch: [{epoch}]"
|
| 222 |
-
|
| 223 |
-
results_per_task = defaultdict(list) # Store results per task
|
| 224 |
-
overall_results = [] # Store all results for overall metrics
|
| 225 |
-
|
| 226 |
-
# Calculate N based on decode_ratio
|
| 227 |
-
if decode_ratio <= 0.0:
|
| 228 |
-
N = float("inf") # Effectively never run generate
|
| 229 |
-
elif decode_ratio >= 1.0:
|
| 230 |
-
N = 1 # Run generate every batch
|
| 231 |
-
else:
|
| 232 |
-
N = max(int(1 / decode_ratio), 1) # Ensure N is at least 1
|
| 233 |
-
|
| 234 |
-
batch_idx = 0
|
| 235 |
-
|
| 236 |
-
# Initialize overall metrics
|
| 237 |
-
overall_res = {
|
| 238 |
-
"loss": torch.tensor(0.0, device=self.device),
|
| 239 |
-
"correct": torch.tensor(0.0, device=self.device),
|
| 240 |
-
"total": torch.tensor(0.0, device=self.device),
|
| 241 |
-
}
|
| 242 |
-
|
| 243 |
-
# Initialize per-task metrics
|
| 244 |
-
per_task_res = defaultdict(
|
| 245 |
-
lambda: {
|
| 246 |
-
"loss": torch.tensor(0.0, device=self.device),
|
| 247 |
-
"correct": torch.tensor(0.0, device=self.device),
|
| 248 |
-
"total": torch.tensor(0.0, device=self.device),
|
| 249 |
-
"n_sample": 0,
|
| 250 |
-
"predicted_texts": [],
|
| 251 |
-
"gold_texts": [],
|
| 252 |
-
}
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
for samples in metric_logger.log_every(dataloader, self.config.run.log_freq, header=header):
|
| 256 |
-
samples = prepare_sample_dist(samples, self.device)
|
| 257 |
-
|
| 258 |
-
with torch.autocast(self.device.type, enabled=self.use_amp):
|
| 259 |
-
forward_result = model(samples, verbose=True)
|
| 260 |
-
|
| 261 |
-
# Extract batch-level loss and correct counts
|
| 262 |
-
batch_loss = forward_result.get("loss", torch.tensor(0.0, device=self.device))
|
| 263 |
-
batch_correct = forward_result.get("correct", torch.tensor(0.0, device=self.device))
|
| 264 |
-
batch_total = forward_result.get("total", torch.tensor(1.0, device=self.device))
|
| 265 |
-
|
| 266 |
-
batch_size = len(samples["id"])
|
| 267 |
-
|
| 268 |
-
# Update overall metrics with batch-level values
|
| 269 |
-
overall_res["loss"] += batch_loss.detach()
|
| 270 |
-
overall_res["correct"] += batch_correct.detach()
|
| 271 |
-
overall_res["total"] += batch_total.detach()
|
| 272 |
-
|
| 273 |
-
# Decide whether to run generate based on decode_ratio
|
| 274 |
-
if decode and (batch_idx % N == 0):
|
| 275 |
-
prompts = samples.get("prompt", None)
|
| 276 |
-
try:
|
| 277 |
-
generated_texts = model.generate(samples, self.config.generate, prompts=prompts)
|
| 278 |
-
except Exception as e:
|
| 279 |
-
print("error in generation", e)
|
| 280 |
-
generated_texts = [None] * batch_size
|
| 281 |
-
else:
|
| 282 |
-
generated_texts = [None] * batch_size # Placeholder if not decoding
|
| 283 |
-
|
| 284 |
-
# Process per-sample data for per-task metrics and result saving
|
| 285 |
-
for i in range(batch_size):
|
| 286 |
-
task = samples["task"][i]
|
| 287 |
-
|
| 288 |
-
# Collect per-task batch-level metrics
|
| 289 |
-
per_task_res[task]["loss"] += batch_loss.detach()
|
| 290 |
-
per_task_res[task]["correct"] += batch_correct.detach()
|
| 291 |
-
per_task_res[task]["total"] += batch_total.detach()
|
| 292 |
-
per_task_res[task]["n_sample"] += 1
|
| 293 |
-
|
| 294 |
-
res = {
|
| 295 |
-
"id": samples["id"][i],
|
| 296 |
-
"ground_truth": samples["text"][i], # Gold label from dataloader
|
| 297 |
-
"task": task,
|
| 298 |
-
"predicted_text": generated_texts[i],
|
| 299 |
-
}
|
| 300 |
-
|
| 301 |
-
if decode and generated_texts[i] is not None:
|
| 302 |
-
res["prompt"] = samples.get("prompt", [None])[i]
|
| 303 |
-
|
| 304 |
-
results_per_task[task].append(res)
|
| 305 |
-
overall_results.append(res)
|
| 306 |
-
|
| 307 |
-
# Collect texts for custom metrics
|
| 308 |
-
if generated_texts[i] is not None:
|
| 309 |
-
per_task_res[task]["predicted_texts"].append(generated_texts[i])
|
| 310 |
-
per_task_res[task]["gold_texts"].append(samples["text"][i])
|
| 311 |
-
|
| 312 |
-
batch_idx += 1 # Increment batch index
|
| 313 |
-
|
| 314 |
-
if save_json:
|
| 315 |
-
for task, task_results in results_per_task.items():
|
| 316 |
-
self.save_result(task_results, self.output_dir, f"eval_{split}_{task}_epoch_{epoch}")
|
| 317 |
-
# Optionally save overall results
|
| 318 |
-
self.save_result(overall_results, self.output_dir, f"eval_{split}_epoch_{epoch}")
|
| 319 |
-
|
| 320 |
-
# Synchronize metrics across processes if in distributed mode
|
| 321 |
-
if is_dist_avail_and_initialized():
|
| 322 |
-
for key in overall_res:
|
| 323 |
-
dist.all_reduce(overall_res[key])
|
| 324 |
-
|
| 325 |
-
overall_ret = {
|
| 326 |
-
"loss": (overall_res["loss"] / batch_idx).item(),
|
| 327 |
-
"agg_metrics": (overall_res["correct"] / overall_res["total"]).item(),
|
| 328 |
-
}
|
| 329 |
-
|
| 330 |
-
if is_main_process():
|
| 331 |
-
# Log overall metrics
|
| 332 |
-
wandb.log(
|
| 333 |
-
{
|
| 334 |
-
f"{split}_loss": overall_ret["loss"],
|
| 335 |
-
f"{split}_accuracy": overall_ret["agg_metrics"],
|
| 336 |
-
"epoch": epoch,
|
| 337 |
-
}
|
| 338 |
-
)
|
| 339 |
-
|
| 340 |
-
# Compute and log per-task metrics
|
| 341 |
-
for task, res in per_task_res.items():
|
| 342 |
-
if "caption-none" in task:
|
| 343 |
-
continue
|
| 344 |
-
|
| 345 |
-
if self.use_distributed:
|
| 346 |
-
print(f"Rank {dist.get_rank()}, task={task}, ")
|
| 347 |
-
|
| 348 |
-
print(
|
| 349 |
-
f"loss={res['loss'].shape, res['loss'].dtype}, "
|
| 350 |
-
f"correct={res['correct'].shape, res['correct'].dtype}, "
|
| 351 |
-
f"total={res['total'].shape, res['total'].dtype}, "
|
| 352 |
-
f"n_sample={res['n_sample']}"
|
| 353 |
-
)
|
| 354 |
-
|
| 355 |
-
# Synchronize metrics across processes if in distributed mode
|
| 356 |
-
if is_dist_avail_and_initialized():
|
| 357 |
-
dist.all_reduce(res["loss"])
|
| 358 |
-
dist.all_reduce(res["correct"])
|
| 359 |
-
dist.all_reduce(res["total"])
|
| 360 |
-
dist.all_reduce(torch.tensor(res["n_sample"], device=self.device))
|
| 361 |
-
|
| 362 |
-
ret = {
|
| 363 |
-
"loss": (res["loss"] / res["n_sample"]).item(),
|
| 364 |
-
"agg_metrics": (res["correct"] / res["total"]).item(),
|
| 365 |
-
}
|
| 366 |
-
|
| 367 |
-
if is_main_process():
|
| 368 |
-
# Log per-task metrics
|
| 369 |
-
wandb.log(
|
| 370 |
-
{
|
| 371 |
-
f"{split}_{task}_loss": ret["loss"],
|
| 372 |
-
f"{split}_{task}_accuracy": ret["agg_metrics"],
|
| 373 |
-
"epoch": epoch,
|
| 374 |
-
}
|
| 375 |
-
)
|
| 376 |
-
|
| 377 |
-
# Get and compute custom metrics for this task
|
| 378 |
-
metrics_list = get_task_metrics(task)
|
| 379 |
-
predicted_texts = res["predicted_texts"]
|
| 380 |
-
gold_texts = res["gold_texts"]
|
| 381 |
-
for metric in metrics_list:
|
| 382 |
-
if predicted_texts and gold_texts:
|
| 383 |
-
metric_value = metric.compute_metric(predicted_texts, gold_texts)
|
| 384 |
-
metric_name = metric.__class__.__name__
|
| 385 |
-
wandb.log(
|
| 386 |
-
{
|
| 387 |
-
f"{split}_{task}_{metric_name}": metric_value,
|
| 388 |
-
"epoch": epoch,
|
| 389 |
-
}
|
| 390 |
-
)
|
| 391 |
-
return overall_ret # Return overall metrics
|
| 392 |
-
|
| 393 |
-
def save_result(self, result, result_dir, filename):
|
| 394 |
-
result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, get_rank()))
|
| 395 |
-
final_result_file = os.path.join(result_dir, "%s.json" % filename)
|
| 396 |
-
|
| 397 |
-
try:
|
| 398 |
-
json.dump(result, open(result_file, "w"), ensure_ascii=False)
|
| 399 |
-
except Exception as e:
|
| 400 |
-
logging.warning(f"Error saving {result_file}. Error: {e}")
|
| 401 |
-
json.dump(result, open(result_file, "w", encoding="utf-8"), ensure_ascii=False)
|
| 402 |
-
|
| 403 |
-
# if is_dist_avail_and_initialized():
|
| 404 |
-
# dist.barrier()
|
| 405 |
-
|
| 406 |
-
if is_main_process():
|
| 407 |
-
logging.info("rank %d starts merging results." % get_rank())
|
| 408 |
-
result = []
|
| 409 |
-
|
| 410 |
-
for rank in range(get_world_size()):
|
| 411 |
-
result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
|
| 412 |
-
try:
|
| 413 |
-
res = json.load(open(result_file, "r"))
|
| 414 |
-
except Exception as e:
|
| 415 |
-
logging.warning(f"Error reading {result_file}. Error: {e}")
|
| 416 |
-
res = json.load(open(result_file, "r", encoding="utf-8"))
|
| 417 |
-
result += res
|
| 418 |
-
|
| 419 |
-
try:
|
| 420 |
-
json.dump(result, open(final_result_file, "w"), ensure_ascii=False)
|
| 421 |
-
except Exception as e:
|
| 422 |
-
logging.warning(f"Error saving {final_result_file}. Error: {e}")
|
| 423 |
-
json.dump(
|
| 424 |
-
result,
|
| 425 |
-
open(final_result_file, "w", encoding="utf-8"),
|
| 426 |
-
ensure_ascii=False,
|
| 427 |
-
)
|
| 428 |
-
|
| 429 |
-
print("result file saved to %s" % final_result_file)
|
| 430 |
-
|
| 431 |
-
def train(self):
|
| 432 |
-
start_time = time.time()
|
| 433 |
-
best_agg_metric = 0
|
| 434 |
-
best_epoch = 0
|
| 435 |
-
|
| 436 |
-
for cur_epoch in range(self.start_epoch, self.max_epoch):
|
| 437 |
-
if self.evaluate_only:
|
| 438 |
-
break
|
| 439 |
-
|
| 440 |
-
# training phase
|
| 441 |
-
logging.info("Training Phase")
|
| 442 |
-
train_stats = self.train_epoch(cur_epoch)
|
| 443 |
-
self.log_stats(train_stats, split_name="train")
|
| 444 |
-
|
| 445 |
-
# validating phase
|
| 446 |
-
logging.info("Validating Phase")
|
| 447 |
-
valid_log = self.valid_epoch(
|
| 448 |
-
cur_epoch,
|
| 449 |
-
"valid",
|
| 450 |
-
decode=self.config.run.custom_metrics,
|
| 451 |
-
save_json=False,
|
| 452 |
-
decode_ratio=self.config.run.decode_ratio,
|
| 453 |
-
)
|
| 454 |
-
if valid_log is not None:
|
| 455 |
-
if is_main_process():
|
| 456 |
-
agg_metrics = valid_log["agg_metrics"]
|
| 457 |
-
if agg_metrics > best_agg_metric:
|
| 458 |
-
best_agg_metric = agg_metrics
|
| 459 |
-
best_epoch = cur_epoch
|
| 460 |
-
self.save_checkpoint(cur_epoch, is_best=True)
|
| 461 |
-
|
| 462 |
-
valid_log.update({"best_epoch": best_epoch})
|
| 463 |
-
self.log_stats(valid_log, split_name="valid")
|
| 464 |
-
self.save_checkpoint(cur_epoch, is_best=False)
|
| 465 |
-
|
| 466 |
-
# if self.use_distributed:
|
| 467 |
-
# dist.barrier()
|
| 468 |
-
|
| 469 |
-
# testing phase
|
| 470 |
-
if self.evaluate_only:
|
| 471 |
-
self.valid_epoch("best", "test", decode=True, save_json=True)
|
| 472 |
-
|
| 473 |
-
total_time = time.time() - start_time
|
| 474 |
-
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 475 |
-
logging.info("Training time {}".format(total_time_str))
|
| 476 |
-
|
| 477 |
-
@main_process
|
| 478 |
-
def log_config(self):
|
| 479 |
-
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
|
| 480 |
-
f.write(json.dumps(self.config.model_dump_json(), indent=4) + "\n")
|
| 481 |
-
|
| 482 |
-
@main_process
|
| 483 |
-
def log_stats(self, stats, split_name):
|
| 484 |
-
if isinstance(stats, dict):
|
| 485 |
-
log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
|
| 486 |
-
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
|
| 487 |
-
f.write(json.dumps(log_stats) + "\n")
|
| 488 |
-
elif isinstance(stats, list):
|
| 489 |
-
pass
|
| 490 |
-
|
| 491 |
-
@main_process
|
| 492 |
-
def save_checkpoint(self, cur_epoch, is_best=False):
|
| 493 |
-
"""
|
| 494 |
-
Save the checkpoint at the current epoch.
|
| 495 |
-
"""
|
| 496 |
-
model_no_ddp = self.unwrap_dist_model(self.model)
|
| 497 |
-
param_grad_dic = {k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()}
|
| 498 |
-
state_dict = model_no_ddp.state_dict()
|
| 499 |
-
for k in list(state_dict.keys()):
|
| 500 |
-
if k in param_grad_dic.keys() and not param_grad_dic[k]:
|
| 501 |
-
# delete parameters that do not require gradient
|
| 502 |
-
del state_dict[k]
|
| 503 |
-
save_obj = {
|
| 504 |
-
"model": state_dict,
|
| 505 |
-
"optimizer": self.optimizer.state_dict(),
|
| 506 |
-
"config": dict(self.config),
|
| 507 |
-
"scaler": self.scaler.state_dict() if self.scaler else None,
|
| 508 |
-
"epoch": cur_epoch,
|
| 509 |
-
}
|
| 510 |
-
save_to = os.path.join(
|
| 511 |
-
self.output_dir,
|
| 512 |
-
"checkpoint_{}.pth".format("best" if is_best else cur_epoch),
|
| 513 |
-
)
|
| 514 |
-
logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
|
| 515 |
-
torch.save(save_obj, save_to)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NatureLM/storage_utils.py
DELETED
@@ -1,26 +0,0 @@
-import logging
-import os
-from functools import lru_cache
-from typing import Union
-
-import cloudpathlib
-from google.cloud.storage.client import Client
-
-logger = logging.getLogger(__name__)
-
-
-def is_gcs_path(path: Union[str, os.PathLike]) -> bool:
-    return str(path).startswith("gs://")
-
-
-@lru_cache(maxsize=1)
-def _get_client():
-    return cloudpathlib.GSClient(storage_client=Client())
-
-
-try:
-    _gcp_storage_client = _get_client()
-except Exception:
-    logger.warning("Failed to initialize GCS client." "Training wont be able to use GSPath or R2Path without a client.")
-    _gcp_storage_client = None
-
|
|
|
|
|
|
|
|
|
|
NatureLM/task_metric_utils.py
DELETED
|
@@ -1,283 +0,0 @@
|
|
| 1 |
-
# Taken from DCASE 2021 Task 5 evaluation source code
|
| 2 |
-
# https://github.com/c4dm/dcase-few-shot-bioacoustic
|
| 3 |
-
# MIT License
|
| 4 |
-
|
| 5 |
-
import mir_eval
|
| 6 |
-
import numpy as np
|
| 7 |
-
import scipy
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def fast_intersect(ref, est):
|
| 11 |
-
"""Find all intersections between reference events and estimated events (fast).
|
| 12 |
-
Best-case complexity: O(N log N + M log M) where N=length(ref) and M=length(est)
|
| 13 |
-
Parameters
|
| 14 |
-
----------
|
| 15 |
-
ref: np.ndarray [shape=(2, n)], real-valued
|
| 16 |
-
Array of reference events. Each column is an event.
|
| 17 |
-
The first row denotes onset times and the second row denotes offset times.
|
| 18 |
-
est: np.ndarray [shape=(2, m)], real-valued
|
| 19 |
-
Array of estimated events. Each column is an event.
|
| 20 |
-
The first row denotes onset times and the second row denotes offset times.
|
| 21 |
-
Returns
|
| 22 |
-
-------
|
| 23 |
-
matches: list of sets, length n, integer-valued
|
| 24 |
-
Property: matches[i] contains the set of all indices j such that
|
| 25 |
-
(ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
|
| 26 |
-
"""
|
| 27 |
-
ref_on_argsort = np.argsort(ref[0, :])
|
| 28 |
-
ref_off_argsort = np.argsort(ref[1, :])
|
| 29 |
-
|
| 30 |
-
est_on_argsort = np.argsort(est[0, :])
|
| 31 |
-
est_off_argsort = np.argsort(est[1, :])
|
| 32 |
-
|
| 33 |
-
est_on_maxindex = est.shape[1]
|
| 34 |
-
est_off_minindex = 0
|
| 35 |
-
estref_matches = [set()] * ref.shape[1]
|
| 36 |
-
refest_matches = [set()] * ref.shape[1]
|
| 37 |
-
for ref_id in range(ref.shape[1]):
|
| 38 |
-
ref_onset = ref[0, ref_on_argsort[ref_id]]
|
| 39 |
-
est_off_sorted = est[1, est_off_argsort[est_off_minindex:]]
|
| 40 |
-
search_result = np.searchsorted(est_off_sorted, ref_onset, side="left")
|
| 41 |
-
est_off_minindex += search_result
|
| 42 |
-
-        refest_match = est_off_argsort[est_off_minindex:]
-        refest_matches[ref_on_argsort[ref_id]] = set(refest_match)
-
-        ref_offset = ref[1, ref_off_argsort[-1 - ref_id]]
-        est_on_sorted = est[0, est_on_argsort[: (1 + est_on_maxindex)]]
-        search_result = np.searchsorted(est_on_sorted, ref_offset, side="right")
-        est_on_maxindex = search_result - 1
-        estref_match = est_on_argsort[: (1 + est_on_maxindex)]
-        estref_matches[ref_off_argsort[-1 - ref_id]] = set(estref_match)
-
-    zip_iterator = zip(refest_matches, estref_matches)
-    matches = [x.intersection(y) for (x, y) in zip_iterator]
-    return matches
-
-
-def iou(ref, est, method="fast"):
-    """Compute pairwise "intersection over union" (IOU) metric between reference
-    events and estimated events.
-    Let us denote by a_i and b_i the onset and offset of reference event i.
-    Let us denote by u_j and v_j the onset and offset of estimated event j.
-    The IOU between events i and j is defined as
-        (min(b_i, v_j)-max(a_i, u_j)) / (max(b_i, v_j)-min(a_i, u_j))
-    if the events are non-disjoint, and equal to zero otherwise.
-    Parameters
-    ----------
-    ref: np.ndarray [shape=(2, n)], real-valued
-        Array of reference events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    est: np.ndarray [shape=(2, m)], real-valued
-        Array of estimated events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    method: str, optional.
-        If "fast" (default), computes pairwise intersections via a custom
-        dynamic programming algorithm, see fast_intersect.
-        If "slow", computes pairwise intersections via bruteforce quadratic
-        search, see slow_intersect.
-    Returns
-    -------
-    S: scipy.sparse.dok.dok_matrix, real-valued
-        Sparse 2-D matrix. S[i,j] contains the IOU between ref[i] and est[j]
-        if these events are non-disjoint and zero otherwise.
-    """
-    n_refs = ref.shape[1]
-    n_ests = est.shape[1]
-    S = scipy.sparse.dok_matrix((n_refs, n_ests))
-
-    if method == "fast":
-        matches = fast_intersect(ref, est)
-    elif method == "slow":
-        matches = slow_intersect(ref, est)
-
-    for ref_id in range(n_refs):
-        matching_ests = matches[ref_id]
-        ref_on = ref[0, ref_id]
-        ref_off = ref[1, ref_id]
-
-        for matching_est_id in matching_ests:
-            est_on = est[0, matching_est_id]
-            est_off = est[1, matching_est_id]
-            intersection = min(ref_off, est_off) - max(ref_on, est_on)
-            union = max(ref_off, est_off) - min(ref_on, est_on)
-            intersection_over_union = intersection / union
-            S[ref_id, matching_est_id] = intersection_over_union
-
-    return S
-
-def compute_intersection(ref, est, method="fast"):
-    """Compute pairwise intersection between reference
-    events and estimated events.
-    Let us denote by a_i and b_i the onset and offset of reference event i.
-    Let us denote by u_j and v_j the onset and offset of estimated event j.
-    The Intersection between events i and j is defined as
-        (min(b_i, v_j)-max(a_i, u_j))
-    if the events are non-disjoint, and equal to zero otherwise.
-    Parameters
-    ----------
-    ref: np.ndarray [shape=(2, n)], real-valued
-        Array of reference events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    est: np.ndarray [shape=(2, m)], real-valued
-        Array of estimated events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    method: str, optional.
-        If "fast" (default), computes pairwise intersections via a custom
-        dynamic programming algorithm, see fast_intersect.
-        If "slow", computes pairwise intersections via bruteforce quadratic
-        search, see slow_intersect.
-    Returns
-    -------
-    S: scipy.sparse.dok.dok_matrix, real-valued
-        Sparse 2-D matrix. S[i,j] contains the Intersection between ref[i] and est[j]
-        if these events are non-disjoint and zero otherwise.
-    """
-    n_refs = ref.shape[1]
-    n_ests = est.shape[1]
-    S = scipy.sparse.dok_matrix((n_refs, n_ests))
-
-    if method == "fast":
-        matches = fast_intersect(ref, est)
-    elif method == "slow":
-        matches = slow_intersect(ref, est)
-
-    for ref_id in range(n_refs):
-        matching_ests = matches[ref_id]
-        ref_on = ref[0, ref_id]
-        ref_off = ref[1, ref_id]
-
-        for matching_est_id in matching_ests:
-            est_on = est[0, matching_est_id]
-            est_off = est[1, matching_est_id]
-            intersection = min(ref_off, est_off) - max(ref_on, est_on)
-            # union = max(ref_off, est_off) - min(ref_on, est_on)
-            # intersection_over_union = intersection / union
-            S[ref_id, matching_est_id] = intersection  # _over_union
-
-    return S
-
-
-def match_events(ref, est, min_iou=0.0, method="fast"):
-    """
-    Compute a maximum matching between reference and estimated event times,
-    subject to a criterion of minimum intersection-over-union (IOU).
-    Given two lists of events ``ref`` (reference) and ``est`` (estimated),
-    we seek the largest set of correspondences ``(ref[i], est[j])`` such that
-    ``iou(ref[i], est[j]) <= min_iou``
-    and such that each ``ref[i]`` and ``est[j]`` is matched at most once.
-    This function is strongly inspired by mir_eval.onset.util.match_events.
-    It relies on mir_eval's implementation of the Hopcroft-Karp algorithm from
-    maximum bipartite graph matching. However, one important difference is that
-    mir_eval's distance function relies purely on onset times, whereas this function
-    considers both onset times and offset times to compute the IOU metric between
-    reference events and estimated events.
-    Parameters
-    ----------
-    ref: np.ndarray [shape=(2, n)], real-valued
-        Array of reference events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    est: np.ndarray [shape=(2, m)], real-valued
-        Array of estimated events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    min_iou: real number in [0, 1). Default: 0.
-        Threshold for minimum amount of intersection over union (IOU) to match
-        any two events. See the iou method for implementation details.
-    method: str, optional.
-        If "fast" (default), computes pairwise intersections via a custom
-        dynamic programming algorithm, see fast_intersect.
-        If "slow", computes pairwise intersections via bruteforce quadratic
-        search, see slow_intersect.
-    Returns
-    -------
-    matching : list of tuples
-        Every tuple corresponds to a match between one reference event and
-        one estimated event.
-        ``matching[i] == (i, j)`` where ``ref[i]`` matches ``est[j]``.
-        Note that all values i and j appear at most once in the list.
-    """
-
-    # Intersect reference events and estimated events
-    S = iou(ref, est, method=method)
-
-    # Threshold intersection-over-union (IOU) ratio
-    S_bool = scipy.sparse.dok_matrix(S > min_iou)
-    hits = S_bool.keys()
-
-    # Construct the bipartite graph
-    G = {}
-    for ref_i, est_i in hits:
-        if est_i not in G:
-            G[est_i] = []
-        G[est_i].append(ref_i)
-
-    # Apply Hopcroft-Karp algorithm (from mir_eval package)
-    # to obtain maximum bipartite graph matching
-    matching = sorted(mir_eval.util._bipartite_match(G).items())
-    return matching
-
-
-def slow_intersect(ref, est):
-    """Find all intersections between reference events and estimated events (slow).
-    Best-case complexity: O(N*M) where N=ref.shape[1] and M=est.shape[1]
-    Parameters
-    ----------
-    ref: np.ndarray [shape=(2, n)], real-valued
-        Array of reference events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    est: np.ndarray [shape=(2, m)], real-valued
-        Array of estimated events. Each column is an event.
-        The first row denotes onset times and the second row denotes offset times.
-    Returns
-    -------
-    matches: list of sets, length n, integer-valued
-        Property: matches[i] contains the set of all indices j such that
-        (ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
-    """
-    matches = []
-    for i in range(ref.shape[1]):
-        matches.append(
-            set(
-                [
-                    j
-                    for j in range(est.shape[1])
-                    if ((ref[0, i] <= est[1, j]) and (ref[1, i] >= est[0, j]))
-                ]
-            )
-        )
-    return
-
-
-def frames_to_st_dict(x, sr=16000):
-    # x : Tensor of shape (batch, time) or (time,). Entries are 2 (POS), 1 (UNK), and 0 (NEG).
-    # returns a list of dicts {"Begin Time (s)" : [...], "End Time (s)" : [...], "Annotation" : [...]} if batch dim exists, or a single dict
-
-    if len(x.size()) == 2:
-        outs = []
-        for i in range(x.size(0)):
-            x_sub = x[i, :]
-            outs.append(_frames_to_st_dict_single(x_sub, sr=sr))
-        return outs
-    else:
-        return _frames_to_st_dict_single(x, sr=sr)
-
-def _frames_to_st_dict_single(x, sr=16000):
-    d = {"Begin Time (s)" : [], "End Time (s)" : [], "Annotation" : []}
-
-    for label_i in [1, 2]:
-
-        labels = x.numpy() == label_i  # POS : 2, UNK : 1, NEG : 0
-
-        starts = np.where((~labels[:-1]) & (labels[1:]))[0] + 1
-        if labels[0]:
-            starts = np.insert(starts, 0, 0)
-
-        ends = np.where((labels[:-1]) & (~labels[1:]))[0] + 1
-        if labels[-1]:
-            ends = np.append(ends, len(labels))
-
-        for start, end in zip(starts, ends):
-            d["Begin Time (s)"].append(start / sr)
-            d["End Time (s)"].append(end / sr)
-            d["Annotation"].append("POS" if label_i == 2 else "UNK")
-
-    return d
|
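For orientation, the removed helpers above all operate on (2, n) arrays whose columns are events, with onsets in row 0 and offsets in row 1. The sketch below is illustrative only and is not part of the diff: it assumes the deleted module is still importable (e.g. from an older checkout), and the event values are made up. Note that slow_intersect, as deleted, ends with a bare return and never returns its matches, so only method="fast" would actually run; the match_events docstring also says ``<=`` where the code thresholds with ``S > min_iou``.

import numpy as np
from NatureLM.task_metric_utils import iou, match_events  # module removed in this commit

# Hypothetical events: columns are (onset, offset) pairs in seconds
ref = np.array([[0.0, 2.0, 5.0],
                [1.0, 3.0, 6.0]])
est = np.array([[0.1, 4.9],
                [0.9, 6.2]])

S = iou(ref, est, method="fast")              # sparse (3, 2) matrix of pairwise IOU values
pairs = match_events(ref, est, min_iou=0.5)   # e.g. [(0, 0), (2, 1)]: (ref index, est index)
print(S.toarray(), pairs)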
NatureLM/task_metrics.py
DELETED
@@ -1,128 +0,0 @@
-import re
-from abc import ABC, abstractmethod
-from typing import List, Tuple
-
-import numpy as np
-
-from NatureLM.task_metric_utils import match_events
-
-# Assume the following functions are imported from the reference implementations:
-# - match_events
-# - iou
-# - fast_intersect
-# - slow_intersect
-# - compute_intersection
-
-
-class Metric(ABC):
-    @abstractmethod
-    def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
-        pass
-
-
-class ExactAccuracy(Metric):
-    """Exact-match accuracy metric."""
-
-    def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
-        predicted_texts = [pt.lower().strip() for pt in predicted_texts]
-        gold_texts = [gt.lower().strip() for gt in gold_texts]
-        correct = sum(p == g for p, g in zip(predicted_texts, gold_texts))
-        return correct / len(gold_texts) if gold_texts else 0.0
-
-
-class FewShot(Metric):
-    """Few-shot learning metric based on event matching using IoU."""
-
-    def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
-        # Initialize counts
-        total_TP = 0
-        total_FP = 0
-        total_FN = 0
-
-        for pred_text, gold_text in zip(predicted_texts, gold_texts):
-            # Extract events from texts
-            pred_events = parse_timestamps_from_text(pred_text)
-            gold_events = parse_timestamps_from_text(gold_text)
-
-            # Convert events to numpy arrays for match_events function
-            # Each event is (start_time, end_time), need to transpose to shape (2, n)
-            pred_array = np.array(pred_events).T if pred_events else np.empty((2, 0))
-            gold_array = np.array(gold_events).T if gold_events else np.empty((2, 0))
-
-            # Use match_events function from the reference implementation
-            matches = match_events(gold_array, pred_array, min_iou=0.5, method="fast")
-
-            TP = len(matches)
-            FP = len(pred_events) - TP
-            FN = len(gold_events) - TP
-
-            total_TP += TP
-            total_FP += FP
-            total_FN += FN
-
-        # Compute precision, recall, and F1 score
-        precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0.0
-        recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0.0
-        f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
-
-        return f1_score
-
-
-class NoneAccuracy(Metric):
-    """Accuracy for cases where 'None' is the correct answer."""
-
-    def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
-        # Normalize texts
-        predicted_texts = [pt.lower().strip() for pt in predicted_texts]
-        gold_texts = [gt.lower().strip() for gt in gold_texts]
-        # Filter indices where gold_text is 'none'
-        indices = [i for i, gt in enumerate(gold_texts) if gt == "none"]
-        if not indices:
-            return 0.0  # No 'None' cases in gold_texts
-        correct = sum(predicted_texts[i] == "none" for i in indices)
-        return correct / len(indices)
-
-
-class MultipleSpeciesAccuracy(Metric):
-    """Accuracy for cases where the correct answer has at least one comma (multiple species)."""
-
-    def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
-        # Normalize texts
-        predicted_texts = [pt.lower().strip() for pt in predicted_texts]
-        gold_texts = [gt.lower().strip() for gt in gold_texts]
-        # Filter indices where gold_text contains at least one comma
-        indices = [i for i, gt in enumerate(gold_texts) if "," in gt]
-        if not indices:
-            return 0.0  # No multiple-species cases in gold_texts
-        correct = sum(predicted_texts[i] == gold_texts[i] for i in indices)
-        return correct / len(indices)
-
-
-def get_task_metrics(task: str) -> List[Metric]:
-    """Get a list of metric instances appropriate for the given task."""
-    all_metrics = []
-    metrics_dict = {}
-
-    if "classification" in task:
-        metrics_dict["ExactAccuracy"] = ExactAccuracy()
-    if "fewshot" in task:
-        metrics_dict["FewShot"] = FewShot()
-    if "detection" in task:
-        metrics_dict["ExactAccuracy"] = ExactAccuracy()  # Ensures no duplicate
-        metrics_dict["NoneAccuracy"] = NoneAccuracy()
-        metrics_dict["MultipleSpeciesAccuracy"] = MultipleSpeciesAccuracy()
-
-    all_metrics = list(metrics_dict.values())
-    return all_metrics
-
-
-def parse_timestamps_from_text(text: str) -> List[Tuple[float, float]]:
-    """
-    Function to parse timestamps from text.
-    Extracts timestamps in the format "start-end" where start and end are floats.
-    """
-    # Regular expression to extract timestamps in the format "start-end"
-    pattern = r"(\d+\.\d+)-(\d+\.\d+)"
-    matches = re.findall(pattern, text)
-    events = [(float(start), float(end)) for start, end in matches]
-    return events
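The deleted FewShot metric parses "start-end" timestamps out of free-form model output and scores them against gold events via IOU >= 0.5 matching. A minimal illustration with made-up strings, assuming the deleted task_metrics and task_metric_utils modules are still importable:

from NatureLM.task_metrics import FewShot, parse_timestamps_from_text  # removed in this commit

pred = "Detected calls at 1.20-2.40 and 7.00-8.10"
gold = "1.00-2.50, 5.00-6.00"

print(parse_timestamps_from_text(pred))          # [(1.2, 2.4), (7.0, 8.1)]
print(FewShot().compute_metric([pred], [gold]))  # one match -> precision 0.5, recall 0.5, F1 = 0.5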
NatureLM/utils.py
CHANGED
@@ -25,9 +25,7 @@ import soundfile as sf
 import torch
 import torch.nn.functional as F
 import torchaudio
-from torch.utils.data import DataLoader
-
-from NatureLM.dist_utils import get_rank, get_world_size
+from torch.utils.data import DataLoader

 logger = logging.getLogger(__name__)

@@ -99,29 +97,6 @@ def now_as_str() -> str:
     return datetime.now().strftime("%Y%m%d%H%M")


-def get_dataloader(dataset, config, is_train=True, use_distributed=True):
-    if use_distributed:
-        sampler = DistributedSampler(dataset, shuffle=is_train, num_replicas=get_world_size(), rank=get_rank())
-    else:
-        sampler = None
-
-    loader = DataLoader(
-        dataset,
-        batch_size=config.batch_size_train if is_train else config.batch_size_eval,
-        num_workers=config.num_workers,
-        pin_memory=False,
-        sampler=sampler,
-        shuffle=sampler is None and is_train,
-        collate_fn=dataset.collater,
-        drop_last=is_train,
-    )
-
-    if is_train:
-        loader = IterLoader(loader, use_distributed=use_distributed)
-
-    return loader
-
-
 def apply_to_sample(f, sample):
     if len(sample) == 0:
         return {}
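The removed get_dataloader helper was a thin wrapper around torch.utils.data.DataLoader with an optional distributed sampler (note that it referenced DistributedSampler and IterLoader without importing them in this module). A minimal single-process sketch of the equivalent direct construction, where `dataset` and `config` are hypothetical stand-ins for the project's dataset and config objects:

from torch.utils.data import DataLoader

# `dataset` and `config` are placeholders; the removed helper assumed the
# dataset exposes a `collater` method and the config exposes batch-size fields.
loader = DataLoader(
    dataset,
    batch_size=config.batch_size_eval,
    num_workers=config.num_workers,
    pin_memory=False,
    shuffle=False,
    collate_fn=dataset.collater,
    drop_last=False,
)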
Space.yaml
CHANGED
@@ -1,3 +1,3 @@
 sdk: gradio
 python_version: 3.10
-hardware:
+hardware: gpu
app.py
CHANGED
@@ -1,37 +1,86 @@
-import
-import
-from collections import Counter
+import warnings
+import numpy as np
 from pathlib import Path
-from typing import
+from typing import Optional
+from collections import Counter

 import gradio as gr
 import torch
+import torchaudio
+import matplotlib.pyplot as plt

 from NatureLM.config import Config
 from NatureLM.models.NatureLM import NatureLM
-from NatureLM.
+from NatureLM.infer import Pipeline
 import spaces

+warnings.filterwarnings("ignore")
+SAMPLE_RATE = 16000  # Default sample rate for NatureLM-audio
+
+
+def get_spectrogram(audio: torch.Tensor) -> plt.Figure:
+    """Generate a spectrogram from the audio tensor."""
+    spectrogram = torchaudio.transforms.Spectrogram(n_fft=1024)(audio)
+    spectrogram = spectrogram.numpy()[0].squeeze()
+    # Convert to matplotlib figure with imshow
+    fig, ax = plt.subplots(figsize=(13, 5))
+    ax.imshow(np.log(spectrogram + 1e-3), aspect="auto", origin="lower", cmap="viridis")
+    ax.set_title("Spectrogram")
+    ax.set_xlabel("Time")
+    # Set x ticks to reflect 0 to audio duration seconds
+    if audio.dim() > 1:
+        duration = audio.size(1) / SAMPLE_RATE
+    else:
+        duration = audio.size(0) / SAMPLE_RATE
+
+    ax.set_xticks([0, spectrogram.shape[1]])
+    ax.set_xticklabels(["0s", f"{duration:.2f}s"])
+    ax.set_ylabel("Frequency")
+    # Set y ticks to reflect 0 to nyquist frequency (sample_rate/2)
+    nyquist_freq = SAMPLE_RATE / 2
+    ax.set_yticks(
+        [
+            0,
+            spectrogram.shape[0] // 4,
+            spectrogram.shape[0] // 2,
+            3 * spectrogram.shape[0] // 4,
+            spectrogram.shape[0] - 1,
+        ]
+    )
+    ax.set_yticklabels(
+        [
+            "0 Hz",
+            f"{nyquist_freq / 4:.0f} Hz",
+            f"{nyquist_freq / 2:.0f} Hz",
+            f"{3 * nyquist_freq / 4:.0f} Hz",
+            f"{nyquist_freq:.0f} Hz",
+        ]
+    )
+    fig.tight_layout()
+
+    return fig
+

 class ModelManager:
     """Manages model loading and state"""
-
+
     def __init__(self):
         self.model: Optional[NatureLM] = None
         self.config: Optional[Config] = None
         self.is_loaded = False
         self.is_loading = False
         self.load_failed = False
-
+
     def check_availability(self) -> tuple[bool, str]:
         """Check if the model is available for download"""
         try:
             from huggingface_hub import model_info
+
             info = model_info("EarthSpeciesProject/NatureLM-audio")
             return True, "Model is available"
         except Exception as e:
             return False, f"Model not available: {str(e)}"
-
+
     def reset_state(self):
         """Reset the model loading state to allow retrying after a failure"""
         self.model = None
@@ -39,7 +88,7 @@ class ModelManager:
         self.is_loading = False
         self.load_failed = False
         return self.get_status()
-
+
     def get_status(self) -> str:
         """Get the current model loading status"""
         if self.is_loaded:
@@ -50,34 +99,35 @@ class ModelManager:
             return "❌ Model failed to load. Please check the configuration."
         else:
             return "⏳ Ready to load model on first use"
-
+
     def load_model(self) -> Optional[NatureLM]:
         """Load the model if needed"""
         if self.is_loaded:
             return self.model
-
+
         if self.is_loading or self.load_failed:
             return None
-
+
         try:
             self.is_loading = True
             print("Loading model...")
-
+
             # Check if model is available first
             available, message = self.check_availability()
             if not available:
                 raise Exception(f"Model not available: {message}")
-
+
             model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
-            model.to("
+            model.to("cpu")
             model.eval()
-
-
+
+            pipe = Pipeline(model)
+            self.model = pipe
             self.is_loaded = True
             self.is_loading = False
             print("Model loaded successfully!")
-            return
-
+            return pipe
+
         except Exception as e:
             print(f"Error loading model: {e}")
             self.is_loading = False
@@ -88,12 +138,44 @@ class ModelManager:
 # Global model manager instance
 model_manager = ModelManager()

-
+
+def take_majority_vote(results: list[list[dict]]) -> list[str]:
+    """For each audio file, take the majority vote of the labels across all windows"""
+    outputs = []
+    for result in results:
+        predictions = [window["prediction"] for window in result]
+        if not predictions:
+            continue
+        # Count occurrences of each label
+        counts = Counter(predictions)
+        # Find the most common label
+        most_common_label, _ = counts.most_common(1)[0]
+        outputs.append(most_common_label)
+
+    return outputs
+
+
 @spaces.GPU
-def prompt_lm(
-
+def prompt_lm(
+    audios: list[str],
+    queries: list[str] | str,
+    window_length_seconds: float = 10.0,
+    hop_length_seconds: float = 10.0,
+    progress=gr.Progress(),
+) -> list[str]:
+    """Generate response using the model
+
+    Args:
+        audios (list[str]): List of audio file paths
+        queries (list[str] | str): Query or list of queries to process
+        window_length_seconds (float): Length of the window for processing audio
+        hop_length_seconds (float): Hop length for processing audio
+
+    Returns:
+        list[str]: List of generated responses for each audio-query pair
+    """
     model = model_manager.load_model()
-
+
     if model is None:
         if model_manager.is_loading:
             return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
@@ -101,284 +183,63 @@ def prompt_lm(audios: list[str], messages: list[dict[str, str]]) -> str:
             return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again using the retry button."
         else:
             return "Demo mode: Model not loaded. Please check the model configuration."
-
-
-
-
-
-
-
-
-        r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
-        "",
-        prompt_text,
-    )
-    prompt_text = re.sub("\\n", r"\\n", prompt_text)
-
-    print(f"{prompt_text=}")
-    with torch.cuda.amp.autocast(dtype=torch.float16):
-        llm_answer = model.generate(samples, model_manager.config.generate, prompts=[prompt_text])
-    return llm_answer[0]
+
+    results: list[list[dict]] = model(
+        audios,
+        queries,
+        window_length_seconds=window_length_seconds,
+        hop_length_seconds=hop_length_seconds,
+        input_sample_rate=None,
+        progress_bar=progress,
+    )
+    return results
-
-
-def _multimodal_textbox_factory():
-    return gr.MultimodalTextbox(
-        value=None,
-        interactive=True,
-        sources="microphone",
-        placeholder="Enter message...",
-        show_label=False,
-        autofocus=True,
-        submit_btn="Send"
     )


 def user_message(content):
     return {"role": "user", "content": content}


-def
-
-
-
-
-
-
-
-def combine_model_inputs(msgs: list[dict[str, str]]) -> dict[str, list[str]]:
-    messages = []
-    files = []
-    for msg in msgs:
-        print(msg, messages, files)
-        match msg:
-            case {"content": (path,)}:
-                messages.append({"role": msg["role"], "content": "<Audio><AudioHere></Audio> "})
-                files.append(path)
-            case _:
-                messages.append(msg)
-
-    # Join consecutive messages from the same role
-    joined_messages = []
-    for msg in messages:
-        if joined_messages and joined_messages[-1]["role"] == msg["role"]:
-            joined_messages[-1]["content"] += msg["content"]
-        else:
-            joined_messages.append(msg)
-
-    return {"messages": joined_messages, "files": files}
+
+def add_message_and_get_response(
+    chatbot_history: list[dict], audio_input: str, chat_input: str
+) -> tuple[list[dict], str]:
+    """Add user message to chat and get model response"""
+    # Load audio with torchaudio and compute spectrogram
+    audio_tensor, sample_rate = torchaudio.load(audio_input)
+    duration = audio_tensor.size(1) / sample_rate
+
+    spectrogram_fig = get_spectrogram(audio_tensor)
+    # Add gr.Plot to chatbot history
+    chatbot_history.append(
+        {"role": "user", "content": gr.Plot(spectrogram_fig, label="Spectrogram")}
+    )
+    # Get response
+    try:
+        response = prompt_lm(
+            audios=[audio_input],
+            queries=[chat_input],
+            window_length_seconds=duration,
+            hop_length_seconds=duration,
+        )
+        # get first item
+        if isinstance(response, list) and len(response) > 0:
+            response = response[0][0]["prediction"]
+        else:
+            response = "No response generated."
+    except Exception as e:
+        print(f"Error generating response: {e}")
+        response = "Error generating response. Please try again."
+
+    # Add user message to chat history
+    chatbot_history.append({"role": "user", "content": "Q: " + chat_input})
+    # Add model response to chat history
+    chatbot_history.append({"role": "assistant", "content": response})
+    return chatbot_history, ""

-
-
-
-
-    response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
-    history.append({"role": "assistant", "content": response})
-    return history
-
-
-def _chat_tab(examples):
-    # Status indicator
-    status_text = gr.Textbox(
-        value=model_manager.get_status(),
-        label="Model Status",
-        interactive=False,
-        visible=True
     )
-
-
-
-
-
-
-
-        resizeable=True
-    )
-
-    chat_input = _multimodal_textbox_factory()
-    send_all = gr.Button("Send all", elem_id="send-all")
-    clear_button = gr.ClearButton(components=[chatbot, chat_input], visible=False)
-
-    chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
-    bot_msg = send_all.click(
-        bot_response,
-        [chatbot],
-        [chatbot],
-        api_name="bot_response",
-    )
-
-    # Update status after bot response
-    bot_msg.then(lambda: model_manager.get_status(), None, [status_text])
-    bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
-    clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])
-
-    gr.Examples(
-        list(examples.values()),
-        chatbot,
-        chatbot,
-        example_labels=list(examples.keys()),
-        examples_per_page=20,
-    )
-
-
-def summarize_batch_results(results):
-    summary = Counter(results)
-    summary_str = "\n".join(f"{k}: {v}" for k, v in summary.most_common())
-    return summary_str
-
-
-def run_batch_inference(files, task, progress=gr.Progress()) -> str:
-    model = model_manager.load_model()
-    if model is None:
-        if model_manager.is_loading:
-            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
-        elif model_manager.load_failed:
-            return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again."
-        else:
-            return "Demo mode: Model not loaded. Please check the model configuration."
-
-    outputs = []
-    prompt = "<Audio><AudioHere></Audio> " + task
-
-    for file in progress.tqdm(files):
-        outputs.append(prompt_lm([file], [{"role": "user", "content": prompt}]))
-
-    batch_summary: str = summarize_batch_results(outputs)
-    report = f"Batch summary:\n{batch_summary}\n\n"
-    return report
-
-
-def multi_extension_glob_mask(mask_base, *extensions):
-    mask_ext = ["[{}]".format("".join(set(c))) for c in zip(*extensions)]
-    if not mask_ext or len(set(len(e) for e in extensions)) > 1:
-        mask_ext.append("*")
-    return mask_base + "".join(mask_ext)
-
-
-def _batch_tab(file_selection: Literal["upload", "explorer"] = "upload"):
-    if file_selection == "explorer":
-        files = gr.FileExplorer(
-            glob=multi_extension_glob_mask("**.", "mp3", "flac", "wav"),
-            label="Select audio files",
-            file_count="multiple",
        )
-
-
-
-
-    process_btn = gr.Button("Process")
-    output = gr.TextArea()
-
-    process_btn.click(
-        run_batch_inference,
-        [files, task],
-        [output],
-    )
-
-
-def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
-    def get_line(row, start, end, annotation):
-        return f"{row}\tSpectrogram 1\t1\t{start}\t{end}\t0\t8000\t{annotation}"
-
-    raven_output = ["Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tAnnotation"]
-    current_offset = 0
-    last_label = ""
-    row = 1
-
-    for offset, label in sorted(outputs.items()):
-        if label != last_label and last_label:
-            raven_output.append(get_line(row, current_offset, offset, last_label))
-            current_offset = offset
-            row += 1
-        if not last_label:
-            current_offset = offset
-        if label != "None":
-            last_label = label
         else:
-
-
-
-
-
-
-
-
-
-
-            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment.", None
-
-    # Check if model failed to load
-    if model_manager.load_failed:
-        return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease refresh the page to try again.", None
-
-    model = model_manager.load_model()
-    if model is None:
-        return "Demo mode: Model not loaded. Please check the model configuration.", None
-
-    cuda_enabled = torch.cuda.is_available()
-    outputs = {}
-    offset = 0
-
-    prompt = f"<Audio><AudioHere></Audio> {task}"
-    prompt = model_manager.config.model.prompt_template.format(prompt)
-
-    for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
-        prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
-        with torch.cuda.amp.autocast(dtype=torch.float16):
-            llm_answers = model.generate(batch, model_manager.config.generate, prompts=prompt_strs)
-        for answer in llm_answers:
-            outputs[offset] = answer
-            offset += hop_len
-
-    report = f"Number of chunks: {len(outputs)}\n\n"
-    for offset in sorted(outputs.keys()):
-        report += f"{offset:02d}s:\t{outputs[offset]}\n"
-
-    raven_output = to_raven_format(outputs, chunk_len=chunk_len)
-    with tempfile.NamedTemporaryFile(mode="w", prefix="raven-", suffix=".txt", delete=False) as f:
-        f.write(raven_output)
-        raven_file = f.name
-
-    return report, raven_file
-
-
-def _long_recording_tab():
-    audio_input = gr.Audio(label="Upload audio file", type="filepath")
-    task = gr.Dropdown(
-        [
-            "What are the common names for the species in the audio, if any?",
-            "Caption the audio.",
-            "Caption the audio, using the scientific name for any animal species.",
-            "Caption the audio, using the common name for any animal species.",
-            "What is the scientific name for the focal species in the audio?",
-            "What is the common name for the focal species in the audio?",
-            "What is the family of the focal species in the audio?",
-            "What is the genus of the focal species in the audio?",
-            "What is the taxonomic name of the focal species in the audio?",
-            "What call types are heard from the focal species in the audio?",
-            "What is the life stage of the focal species in the audio?",
-        ],
-        label="Tasks",
-        allow_custom_value=True,
-    )
-    with gr.Accordion("Advanced options", open=False):
-        hop_len = gr.Slider(1, 10, 5, label="Hop length (seconds)", step=1)
-        chunk_len = gr.Slider(1, 10, 10, label="Chunk length (seconds)", step=1)
-    process_btn = gr.Button("Process")
-    output = gr.TextArea()
-    download_raven = gr.DownloadButton("Download Raven file")
-
-    process_btn.click(
-        _run_long_recording_inference,
-        [audio_input, task, chunk_len, hop_len],
-        [output, download_raven],
-    )


 def main(
     assets_dir: Path,
     cfg_path: str | Path,
     options: list[str] = [],
-    device: str = "cuda",
 ):
     # Load configuration
     try:
@@ -394,7 +255,7 @@ def main(
     if not assets_dir.exists():
         print(f"Warning: Assets directory {assets_dir} does not exist")
         assets_dir.mkdir(exist_ok=True)
-
+
     # Create placeholder audio files if they don't exist
     laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
     frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
@@ -411,7 +272,9 @@ def main(
         "Caption the audio (Green Tree Frog)": [
             [
                 user_message({"path": str(frog_audio)}),
-                user_message(
+                user_message(
+                    "Caption the audio, using the common name for any animal species."
+                ),
             ]
         ],
         "Caption the audio (American Robin)": [
@@ -428,17 +291,31 @@ def main(
         ],
     }

-    with gr.Blocks(
+    with gr.Blocks(
+        title="NatureLM-audio",
+        theme=gr.themes.Base(
+            primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]
+        ),
+    ) as app:
         header = gr.HTML("""
     <div style="display: flex; align-items: center; gap: 12px;"><h2 style="margin: 0;">NatureLM-audio<span style="font-size: 0.55em; color: #28a745; background: #e6f4ea; padding: 2px 6px; border-radius: 4px; margin-left: 8px; display: inline-block; vertical-align: top;">BETA</span></h2></div>

     """)
-
+
         with gr.Tabs():
             with gr.Tab("Analyze Audio"):
-                uploaded_audio = gr.State()
-
-
+                uploaded_audio = gr.State()
+                # Status indicator
+                # status_text = gr.Textbox(
+                #     value=model_manager.get_status(),
+                #     label="Model Status",
+                #     interactive=False,
+                #     visible=True,
+                # )
+
+                with gr.Column(visible=True) as onboarding_message:
+                    gr.HTML(
+                        """
     <div style="
         background: transparent;
         border: 1px solid #e5e7eb;
@@ -476,45 +353,102 @@ def main(
         onmouseout="this.style.background='#3b82f6';"
     >View Tutorial</a>
     </div>
-                    """,
+                    """,
+                        padding=False,
+                    )
                 with gr.Column(visible=True) as upload_section:
-                    audio_input = gr.Audio(
                         type="filepath",
-                        container=True,
-                        interactive=True,
-                        sources=[
-
+                    audio_input = gr.Audio(
+                        type="filepath",
+                        container=True,
+                        interactive=True,
+                        sources=["upload"],
+                    )
                 with gr.Group(visible=False) as chat:
-                    chatbot = gr.Chatbot(
-                        elem_id="chatbot",
-                        type="messages",
                         render_markdown=False,
-                        feedback_options=[
-
+                    chatbot = gr.Chatbot(
+                        elem_id="chatbot",
+                        type="messages",
+                        label="Chat",
+                        render_markdown=False,
+                        feedback_options=[
+                            "like",
+                            "dislike",
+                            "wrong species",
+                            "incorrect response",
+                            "other",
+                        ],
+                        resizeable=True,
+                    )
+                    gr.Markdown("### Your Query")
+                    task_dropdown = gr.Dropdown(
+                        [
+                            "What are the common names for the species in the audio, if any?",
+                            "Caption the audio.",
+                            "Caption the audio, using the scientific name for any animal species.",
+                            "Caption the audio, using the common name for any animal species.",
+                            "What is the scientific name for the focal species in the audio?",
+                            "What is the common name for the focal species in the audio?",
+                            "What is the family of the focal species in the audio?",
+                            "What is the genus of the focal species in the audio?",
+                            "What is the taxonomic name of the focal species in the audio?",
+                            "What call types are heard from the focal species in the audio?",
+                            "What is the life stage of the focal species in the audio?",
+                        ],
+                        label="Pre-configured Tasks",
+                        allow_custom_value=True,
+                        info="Select a task or enter a custom query below",
+                    )
+                    chat_input = gr.Textbox(
+                        placeholder="e.g. 'Caption this audio'...",
+                        type="text",
+                        label="Query",
+                        lines=2,
+                        show_label=True,
+                        container=False,
+                        submit_btn="Send",
+                        elem_id="chat-input",
+                    )
+
+                    # if task_dropdown is selected, set chat_input to that value
+                    def set_query(task):
+                        if task:
+                            return gr.update(value=task)
+                        return gr.update(value="")
+
+                    task_dropdown.change(
+                        fn=set_query,
+                        inputs=[task_dropdown],
+                        outputs=[chat_input],
+                    )
+
+                    clear_button = gr.ClearButton(
+                        components=[chatbot, chat_input, audio_input], visible=False
                     )
-                    chat_input = _multimodal_textbox_factory()
-                    send_all = gr.Button("Send all")

-
                 def start_chat_interface(audio_path):
-                    return (
-                        gr.update(visible=False),
+                    return (
+                        gr.update(visible=False),  # hide onboarding message
                         gr.update(visible=True),  # show upload section
-                        gr.update(visible=True),  # show chat box
+                        gr.update(visible=True),  # show chat box
                     )

                 audio_input.change(
                     fn=start_chat_interface,
                     inputs=[audio_input],
-                    outputs=[onboarding_message, upload_section, chat]
+                    outputs=[onboarding_message, upload_section, chat],
                 )

-                chat_input.submit(
-
+                chat_input.submit(
+                    add_message_and_get_response,
+                    inputs=[chatbot, audio_input, chat_input],
+                    outputs=[chatbot, chat_input],
+                ).then(lambda: gr.ClearButton(visible=True), None, [clear_button])

+                clear_button.click(
+                    lambda: gr.ClearButton(visible=False), None, [clear_button]
+                )

             with gr.Tab("Sample Library"):
-                gr.Markdown("## Sample Library\n\nExplore example audio files below.")
+                gr.Markdown("## Sample Library\n\nExplore example audio files below.")
                 gr.Examples(
                     list(examples.values()),
                     chatbot,
@@ -523,10 +457,10 @@ def main(
                     examples_per_page=20,
                 )
             with gr.Tab("💡 Help"):
-                gr.Markdown("## User Guide")
-                gr.Markdown("## Share Feedback")
-                gr.Markdown("## FAQs")
-
+                gr.Markdown("## User Guide")  # to fill out
+                gr.Markdown("## Share Feedback")  # to fill out
+                gr.Markdown("## FAQs")  # to fill out
+
         app.css = """
         .welcome-banner {
             background: transparent !important;
@@ -550,7 +484,7 @@ def main(
             _batch_tab()
         with gr.Tab("Long Recording"):
             _long_recording_tab() """
-
+
     return app

@@ -559,8 +493,7 @@ app = main(
     assets_dir=Path("assets"),
     cfg_path=Path("configs/inference.yml"),
     options=[],
-    device="cuda",
)

 if __name__ == "__main__":
-    app.launch()
+    app.launch()
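As the new app.py shows, inference now goes through NatureLM.infer.Pipeline rather than hand-built sample batches: the pipeline object is called with audio file paths and queries and returns, per audio file, a list of per-window dicts carrying a "prediction" key. A minimal usage sketch of that call shape outside Gradio follows; the audio path and query are placeholders, and it assumes the optional input_sample_rate and progress_bar arguments have sensible defaults when omitted.

from NatureLM.models.NatureLM import NatureLM
from NatureLM.infer import Pipeline

model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
model.eval()
pipe = Pipeline(model)

# "example.wav" is a placeholder path; the query matches one of the app's preset tasks
results = pipe(
    ["example.wav"],
    ["What is the common name for the focal species in the audio?"],
    window_length_seconds=10.0,
    hop_length_seconds=10.0,
)
print(results[0][0]["prediction"])  # prediction for the first window of the first file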
requirements.txt
CHANGED
@@ -1,31 +1,19 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-pydub>=0.25.1
-pyyaml>=6.0
-resampy>=0.3.1
-scipy>=1.14.0
-soundfile>=0.12.1
-tensorboard>=2.18.0
-tensorboardX>=2.6.2.2
-spaces>=0.39.0
-tqdm>=4.66.4
-wandb>=0.17.3
-click>=8.1.7
-git+https://github.com/earthspecies/beans-zero.git
+click>=8.2.1
+einops>=0.8.1
+gradio>=5.42.0
+librosa>=0.11.0
+pandas>=2.3.1
+peft>=0.17.0
+plumbum>=1.9.0
+pydantic>=2.11.7
+pydantic-settings>=2.10.1
+pyyaml>=6.0.2
+resampy>=0.4.3
+scipy>=1.15.3
+soundfile>=0.13.1
+spaces>=0.40.0
+torch>=2.8.0
+torchaudio>=2.8.0
+tqdm>=4.67.1
+transformers[sentencepiece]>=4.55.2
+matplotlib>=3.10.5