gagannarula committed · verified
Commit 32d3fde · 1 Parent(s): 1fe07a9

App-redesign (#1)


- New app design, removing unnecessary code (5c7c0e25f77218ec5162453df6d15fcd71550f82)
- Merge branch 'gagan/redo-design' into pr/1 (1b7da9f7b82d5953b20899553e700fee83290a79)
- Small fix in launch (8ff0489f2c0201c4233e566b85acbd78ccfc632c)

NatureLM/augmentations.py DELETED
@@ -1,349 +0,0 @@
1
- import logging
2
- import random
3
-
4
- import numpy as np
5
- import torch as th
6
- from torch import nn
7
- from torch.nn import functional as F
8
-
9
- from NatureLM.utils import mel_frequencies
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
- class RevEcho(nn.Module):
15
- """
16
- Hacky Reverb but runs on GPU without slowing down training. This reverb adds a
17
- succession of attenuated echos of the input signal to itself. Intuitively, the delay
18
- of the first echo will happen after roughly 2x the radius of the room and is
19
- controlled by `first_delay`. Then RevEcho keeps adding echos with the same delay and
20
- further attenuation until the amplitude ratio between the last and first echo is
21
- 1e-3. The attenuation factor and the number of echos to add are controlled by RT60
22
- (measured in seconds). RT60 is the average time to get to -60dB (n.b. volume is
23
- measured over the squared amplitude so this matches the 1e-3 ratio).
24
-
25
- At each call to RevEcho, `first_delay`, `initial` and `RT60` are sampled from their
26
- range. Then, to prevent this reverb from being too regular, the delay time is
27
- resampled uniformly within `first_delay +/- 10%`, as controlled by the `jitter`
28
- parameter.
29
-
30
- Finally, for a denser reverb, multiple trains of echos are added with different
31
- jitter noises.
32
-
33
- Args:
34
- - initial: amplitude of the first echo as a fraction of the input signal. For
35
- each sample, actually sampled from `[0, initial]`. Larger values mean louder
36
- reverb. Physically, this would depend on the absorption of the room walls.
37
- - rt60: range of values to sample the RT60 in seconds, i.e. after RT60 seconds,
38
- the echo amplitude is 1e-3 of the first echo. The default values follow the
39
- recommendations of https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf,
40
- Section 2.4. Physically this would also be related to the absorption of the
41
- room walls and there is likely a relation between `RT60` and `initial`, which
42
- we ignore here.
43
- - first_delay: range of values to sample the first echo delay in seconds. The
44
- default values are equivalent to sampling a room of 3 to 10 meters.
45
- - repeat: how many trains of echos with different jitters to add. Higher values
46
- mean a denser reverb.
47
- - jitter: jitter used to make each repetition of the reverb echo train slightly
48
- different. For instance a jitter of 0.1 means the delay between two echos will
49
- be in the range `first_delay +- 10%`, with the jittering noise being resampled
50
- after each single echo.
51
- - keep_clean: fraction of the reverb of the clean speech to add back to the
52
- ground truth. 0 = dereverberation, 1 = no dereverberation.
53
- - sample_rate: sample rate of the input signals.
54
- """
55
-
56
- def __init__(
57
- self,
58
- proba=0.5,
59
- initial=0.3,
60
- rt60=(0.3, 1.3),
61
- first_delay=(0.01, 0.03),
62
- repeat=3,
63
- jitter=0.1,
64
- keep_clean=0.1,
65
- sample_rate=16000,
66
- rng=None,
67
- seed=42,
68
- ):
69
- super().__init__()
70
-
71
- self.proba = proba
72
- self.initial = initial
73
- self.rt60 = rt60
74
- self.first_delay = first_delay
75
- self.repeat = repeat
76
- self.jitter = jitter
77
- self.keep_clean = keep_clean
78
- self.sample_rate = sample_rate
79
- self.seed = seed
80
- self.rng = rng if rng is not None else random.Random(self.seed)
81
-
82
- def _reverb(self, source, initial, first_delay, rt60):
83
- """
84
- Return the reverb for a single source.
85
- """
86
- length = source.shape[-1]
87
- reverb = th.zeros_like(source)
88
-
89
- for _ in range(self.repeat):
90
- frac = 1 # what fraction of the first echo amplitude is still here
91
- echo = initial * source
92
- while frac > 1e-3:
93
- # First jitter noise for the delay
94
- jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
95
- delay = min(1 + int(jitter * first_delay * self.sample_rate), length)
96
-
97
- # Delay the echo in time by padding with zero on the left
98
- echo = F.pad(echo[:, :, :-delay], (delay, 0))
99
- reverb += echo
100
-
101
- # Second jitter noise for the attenuation
102
- jitter = 1 + self.jitter * self.rng.uniform(-1, 1)
103
- # we want, with `d` the attenuation, d**(rt60 / first_ms) = 1e-3
104
- # i.e. log10(d) = -3 * first_ms / rt60, so that
105
- attenuation = 10 ** (-3 * jitter * first_delay / rt60)
106
- echo *= attenuation
107
- frac *= attenuation
108
-
109
- return reverb
110
-
111
- def forward(self, samples):
112
- if self.rng.random() >= self.proba:
113
- return samples
114
-
115
- raw_wav = samples.get("raw_wav", None)
116
-
117
- # add channel dimension if not exist
118
- if raw_wav.dim() == 2:
119
- raw_wav = raw_wav.unsqueeze(1)
120
-
121
- # Sample characteristics for the reverb
122
- initial = self.rng.random() * self.initial
123
- first_delay = self.rng.uniform(*self.first_delay)
124
- rt60 = self.rng.uniform(*self.rt60)
125
-
126
- reverb_wav = self._reverb(raw_wav, initial, first_delay, rt60)
127
- raw_wav += self.keep_clean * reverb_wav
128
-
129
- # remove channel dimension
130
- if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
131
- raw_wav = raw_wav.squeeze(1)
132
-
133
- samples["raw_wav"] = raw_wav
134
- return samples
135
-
136
-
137
- class BandMask(nn.Module):
138
- """
139
- Masks bands of frequencies. Similar to Park, Daniel S., et al.
140
- "Specaugment: A simple data augmentation method for automatic speech recognition."
141
- (https://arxiv.org/pdf/1904.08779.pdf) but over the waveform.
142
- """
143
-
144
- def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000, rng=None, seed=42):
145
- """__init__.
146
-
147
- :param maxwidth: the maximum width to remove
148
- :param bands: number of bands
149
- :param sample_rate: signal sample rate
150
- """
151
- super().__init__()
152
- self.maxwidth = maxwidth
153
- self.bands = bands
154
- self.sample_rate = sample_rate
155
- self.seed = seed
156
- self.rng = rng if rng is not None else random.Random(self.seed)
157
-
158
- def forward(self, samples):
159
- raw_wav = samples.get("raw_wav", None)
160
-
161
- # add channel dimension if not exist
162
- if raw_wav.dim() == 2:
163
- raw_wav = raw_wav.unsqueeze(1)
164
-
165
- bands = self.bands
166
- bandwidth = int(abs(self.maxwidth) * bands)
167
- mels = mel_frequencies(bands, 40, self.sample_rate / 2) / self.sample_rate
168
- low = self.rng.randrange(bands)
169
- high = self.rng.randrange(low, min(bands, low + bandwidth))
170
-
171
- filters = LowPassFilters([mels[low], mels[high]]).to(raw_wav.device)
172
-
173
- low, midlow = filters(raw_wav)
174
- # band pass filtering
175
- out = raw_wav - midlow + low
176
-
177
- # remove channel dimension
178
- if out.dim() == 3 and out.shape[1] == 1:
179
- out = out.squeeze(1)
180
-
181
- samples["raw_wav"] = out
182
- return samples
183
-
184
-
185
- class Shift(nn.Module):
186
- def __init__(self, shift=8192, same=False, rngth=None):
187
- """
188
- :param shift: randomly shifts the signals up to a given factor
189
- :param same: shifts both clean and noisy files by the same factor
190
- """
191
- super().__init__()
192
- self.shift = shift
193
- self.same = same
194
- self.rngth = rngth
195
-
196
- def forward(self, samples):
197
- raw_wav = samples.get("raw_wav", None)
198
- batch, channels, length = raw_wav.shape
199
- length = length - self.shift
200
- if self.shift > 0:
201
- offsets = th.randint(
202
- self.shift, [1 if self.same else batch, 1, 1], device=raw_wav.device, generator=self.rngth
203
- )
204
- offsets = offsets.expand(-1, channels, -1)
205
- indexes = th.arange(length, device=raw_wav.device)
206
- import pdb
207
-
208
- pdb.set_trace()
209
- raw_wav = raw_wav.gather(2, indexes + offsets)
210
- samples["raw_wav"] = raw_wav
211
- return samples
212
-
213
-
214
- class TimeScale(nn.Module):
215
- """Fast time scale."""
216
-
217
- def __init__(self, scale=2.0, target=1, rngnp=None, seed=42):
218
- """
219
- :param scale: randomly scales up to this maximum factor
220
- """
221
- super().__init__()
222
- self.scale = scale
223
- self.target = target
224
- self.seed = seed
225
- self.rngnp = rngnp if rngnp is not None else np.random.default_rng(seed=self.seed)
226
-
227
- def forward(self, samples):
228
- try:
229
- raw_wav = samples.get("raw_wav")
230
- except KeyError:
231
- logger.error("Missing required key 'raw_wav' in samples dict")
232
- raise
233
-
234
- if "padding_mask" in samples:
235
- masks = samples.get("padding_mask")
236
- else:
237
- masks = th.ones_like(raw_wav)
238
-
239
- # add channel dimension if not exist
240
- if raw_wav.dim() == 2:
241
- raw_wav = raw_wav.unsqueeze(1)
242
- masks = masks.unsqueeze(1)
243
-
244
- # what to augment: noise, clean, or both
245
- if self.target == -1:
246
- targets = [i for i in range(raw_wav.shape[0])]
247
- else:
248
- targets = [self.target]
249
-
250
- for t in targets:
251
- signal = raw_wav[t]
252
- scaling = np.power(self.scale, self.rngnp.uniform(-1, 1))
253
- output_size = int(signal.shape[-1] * scaling)
254
- ref = th.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling)
255
-
256
- ref1 = ref.clone().type(th.int64)
257
- ref2 = th.min(ref1 + 1, th.full_like(ref1, signal.shape[-1] - 1, dtype=th.int64))
258
- r = ref - ref1.type(ref.type())
259
- scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r
260
- scaled_masks = masks[t][..., ref1] * (1 - r) + masks[t][..., ref2] * r
261
-
262
- # trim or zero pad to the original size
263
- if scaled_signal.shape[-1] > signal.shape[-1]:
264
- nframes_offset = (scaled_signal.shape[-1] - signal.shape[-1]) // 2
265
- scaled_signal = scaled_signal[..., nframes_offset : nframes_offset + signal.shape[-1]]
266
- scaled_masks = scaled_masks[..., nframes_offset : nframes_offset + signal.shape[-1]]
267
- else:
268
- nframes_diff = signal.shape[-1] - scaled_signal.shape[-1]
269
- pad_left = int(np.random.uniform() * nframes_diff)
270
- pad_right = nframes_diff - pad_left
271
- scaled_signal = F.pad(
272
- input=scaled_signal, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
273
- )
274
- scaled_masks = F.pad(
275
- input=scaled_masks, pad=(pad_left, pad_right, 0, 0, 0, 0), mode="constant", value=0
276
- )
277
- raw_wav[t] = scaled_signal
278
- masks[t] = scaled_masks
279
-
280
- # remove channel dimension
281
- if raw_wav.dim() == 3 and raw_wav.shape[1] == 1:
282
- raw_wav = raw_wav.squeeze(1)
283
- masks = masks.squeeze(1)
284
-
285
- samples["raw_wav"] = raw_wav
286
- samples["padding_mask"] = masks
287
-
288
- return samples
289
-
290
-
291
- class Flip(nn.Module):
292
- def __init__(self, p=0.0, rngth=None):
293
- super(Flip, self).__init__()
294
-
295
- self.p = p
296
- self.rngth = rngth
297
-
298
- def forward(self, samples):
299
- raw_wav = samples["raw_wav"]
300
- if raw_wav.dim() > 2:
301
- flip_mask = th.rand(raw_wav.shape[0], device=raw_wav.device, generator=self.rngth) <= self.p
302
- raw_wav[flip_mask] = raw_wav[flip_mask].flip(-1)
303
- else:
304
- if th.rand(1, generator=self.rngth) <= self.p:
305
- raw_wav = raw_wav.flip(0)
306
- samples["raw_wav"] = raw_wav
307
- return samples
308
-
309
-
310
- class LowPassFilters(th.nn.Module):
311
- """
312
- Bank of low pass filters.
313
-
314
- Args:
315
- cutoffs (list[float]): list of cutoff frequencies, in [0, 1] expressed as `f/f_s` where
316
- f_s is the samplerate.
317
- width (int | None): width of the filters (i.e. kernel_size=2 * width + 1).
318
- Default to `2 / min(cutoffs)`. Longer filters will have better attenuation
319
- but more side effects.
320
- Shape:
321
- - Input: `(*, T)`
322
- - Output: `(F, *, T)` with `F` the length of `cutoffs`.
323
- """
324
-
325
- def __init__(self, cutoffs: list, width: int | None = None):
326
- super().__init__()
327
-
328
- self.cutoffs = cutoffs
329
-
330
- if not width:
331
- width = int(2 / min(cutoffs))
332
- self.width = width
333
-
334
- window = th.hamming_window(2 * width + 1, periodic=False)
335
- t = np.arange(-width, width + 1, dtype=np.float32)
336
- filters = []
337
- for cutoff in cutoffs:
338
- sinc = th.from_numpy(np.sinc(2 * cutoff * t))
339
- filters.append(2 * cutoff * sinc * window)
340
- self.register_buffer("filters", th.stack(filters).unsqueeze(1))
341
-
342
- def forward(self, input):
343
- *others, t = input.shape
344
- input = input.view(-1, 1, t)
345
- out = F.conv1d(input, self.filters, padding=self.width)
346
- return out.permute(1, 0, 2).reshape(-1, *others, t)
347
-
348
- def __repr__(self):
349
- return "LossPassFilters(width={},cutoffs={})".format(self.width, self.cutoffs)
NatureLM/checkpoint_utils.py CHANGED
@@ -42,27 +42,6 @@ def get_state_dict(model, drop_untrained_params: bool = True) -> dict[str, Any]:
42
  return state_dict
43
 
44
 
45
- def torch_save_to_bucket(save_obj: Any, save_path: Union[str, os.PathLike], compress: bool = True) -> None:
46
- """Save an object directly to GCS bucket without intermediate disk storage.
47
-
48
- Args:
49
- save_obj: Object to save (usually model state dict or checkpoint)
50
- save_path: Path to save in GCS bucket (must be gs:// path)
51
- compress: Whether to use compression. Default: True
52
- """
53
- if not is_gcs_path(save_path):
54
- raise ValueError("save_path must be a GCS path")
55
-
56
- # save to a temporary local file and then upload to GCS
57
- with tempfile.NamedTemporaryFile() as tmp:
58
- torch.save(save_obj, tmp.name, _use_new_zipfile_serialization=compress)
59
- try:
60
- save_path.upload_from(tmp.name)
61
- except Exception as e:
62
- logger.error(f"Error saving to GCP bucket: {e}")
63
- raise e
64
-
65
-
66
  def save_model_checkpoint(
67
  model: nn.Module,
68
  save_path: Union[str, os.PathLike],
@@ -82,7 +61,7 @@ def save_model_checkpoint(
82
  extention (str): Extension to use for the checkpoint file. Default: "pth".
83
  **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
84
  """
85
- if not is_gcs_path(save_path) and not os.path.exists(os.path.dirname(save_path)):
86
  raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")
87
 
88
  model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
@@ -93,8 +72,4 @@ def save_model_checkpoint(
93
  }
94
 
95
  logger.info("Saving checkpoint to {}.".format(save_path))
96
-
97
- if is_gcs_path(save_path):
98
- torch_save_to_bucket(save_obj, save_path)
99
- else:
100
- torch.save(save_obj, save_path)
 
42
  return state_dict
43
 
44
 
 
45
  def save_model_checkpoint(
46
  model: nn.Module,
47
  save_path: Union[str, os.PathLike],
 
61
  extention (str): Extension to use for the checkpoint file. Default: "pth".
62
  **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
63
  """
64
+ if not os.path.exists(os.path.dirname(save_path)):
65
  raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")
66
 
67
  model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
 
72
  }
73
 
74
  logger.info("Saving checkpoint to {}.".format(save_path))
75
+ torch.save(save_obj, save_path)
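
With the GCS branch removed, save_model_checkpoint now always writes through torch.save to a local path. A minimal local-only sketch of the retained flow (the helper name and the extra keyword argument are illustrative, not part of the module):

import os
import torch
from torch import nn

def save_local_checkpoint(model: nn.Module, save_path: str, **objects_to_save) -> None:
    # Mirrors the retained behaviour: the target directory must already exist,
    # and the checkpoint goes straight to disk with torch.save.
    if not os.path.exists(os.path.dirname(save_path)):
        raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")
    save_obj = {"model": model.state_dict(), **objects_to_save}
    torch.save(save_obj, save_path)

save_local_checkpoint(nn.Linear(4, 2), "./checkpoint_3.pth", epoch=3)
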
 
 
 
 
NatureLM/dataset.py DELETED
@@ -1,550 +0,0 @@
1
- # Copyright (2024) Earth Species Project
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- """
17
- Mixing examples.
18
- Can mix:
19
- - base: options-detection add: open-ended:
20
- Take all open-ended labels. Add them to the options. Add them to the labels.
21
- - base: open-ended, add: open-ended
22
- Concatenate labels
23
- """
24
-
25
- import glob
26
- import json
27
- import os
28
- import random
29
- from collections import defaultdict
30
- from pathlib import Path
31
- from typing import Literal
32
-
33
- import numpy as np
34
- import soundfile as sf
35
- import torch
36
- from torch.nn.utils.rnn import pad_sequence
37
- from torch.utils.data import Dataset
38
-
39
- from NatureLM.utils import snr_scale, time_scale
40
-
41
-
42
- def write_example_to_file(base_filename, audio, sr=16000, suffix="_output", save_dir="debug_outputs"):
43
- """
44
- Writes the audio tensor to a file for debugging or inspection purposes.
45
-
46
- Args:
47
- base_filename (str): The base name of the original file.
48
- audio (torch.Tensor or numpy.ndarray): The audio waveform to save.
49
- sr (int): Sampling rate of the audio (default: 16000 Hz).
50
- suffix (str): Optional suffix to append to the filename.
51
- save_dir (str): Directory where the files will be saved.
52
- """
53
- if isinstance(audio, torch.Tensor):
54
- audio = audio.numpy() # Convert to numpy if necessary
55
-
56
- # Ensure the save directory exists
57
- os.makedirs(save_dir, exist_ok=True)
58
-
59
- # Create the output file path
60
- filename = f"{os.path.splitext(base_filename)[0]}{suffix}.wav"
61
- output_path = os.path.join(save_dir, filename)
62
-
63
- try:
64
- # Write the audio to the file
65
- sf.write(output_path, audio, sr)
66
- print(f"Saved audio to {output_path}")
67
- except Exception as e:
68
- print(f"Failed to write audio to file: {e}")
69
-
70
-
71
- # Example usage in your code
72
- # write_example_to_file(os.path.basename(ann["path"]), audio, suffix="_ts")
73
-
74
-
75
- def collater(samples):
76
- """Collate samples into a batch.
77
-
78
- Samples is a list of dictionaries, each containing the following keys:
79
- - raw_wav: a list of tensors containing the raw audio waveform
80
- - text: a list of strings containing the text
81
- - task: a list of strings containing the task
82
- - id: a list of strings containing the id
83
- - prompt: a list of strings containing the prompt
84
- - index: a list of integers containing the index
85
-
86
- The individual audio waveforms will be stacked along the batch dimension for easier
87
- processing in the audio model. To keep which audio belongs to which sample, we add
88
- the audio_chunk_sizes key to the batch dictionary.
89
- """
90
- flat_raw_wav = []
91
- audio_chunk_sizes = []
92
-
93
- for s in samples:
94
- chunk_size = len(s["raw_wav"])
95
- audio_chunk_sizes.append(chunk_size)
96
- flat_raw_wav.extend(s["raw_wav"])
97
- # raw_wav = [torch.from_numpy(a) for a in flat_raw_wav]
98
- raw_wav = flat_raw_wav
99
- raw_wav_length = torch.tensor([len(a) for a in raw_wav])
100
- raw_wav = pad_sequence(raw_wav, batch_first=True, padding_value=0)
101
- paddding_mask = torch.arange(raw_wav.size(1)).unsqueeze(0) >= raw_wav_length.unsqueeze(1)
102
-
103
- text = [s["text"] for s in samples]
104
- prompt = [s["prompt"] for s in samples]
105
- task = [s["task"] for s in samples]
106
- id = [s["id"] for s in samples]
107
- index = [s["index"] for s in samples]
108
-
109
- return {
110
- "raw_wav": raw_wav,
111
- "padding_mask": paddding_mask,
112
- "text": text,
113
- "task": task,
114
- "id": id,
115
- "prompt": prompt,
116
- "index": index,
117
- "audio_chunk_sizes": audio_chunk_sizes,
118
- }
119
-
120
-
121
- class NatureLMDataset(Dataset):
122
- def __init__(
123
- self,
124
- ann_path: str | Path,
125
- *,
126
- max_length_seconds: int = 10,
127
- cropping: Literal["random", "start"] | None = "random",
128
- noise_prob: float = 0.0,
129
- noise_dirs: list[str] | list[Path] | None = None,
130
- low_snr: float = -5,
131
- high_snr: float = 20,
132
- time_scale_prob: float = 0.0,
133
- time_scale: float = 1.2,
134
- seed: int = 0,
135
- mixup_prob: float = 0.0,
136
- mixup_count: int = 3,
137
- use_augmentation: bool = False,
138
- mask_audio_prob: float = 0.0,
139
- ):
140
- super().__init__()
141
-
142
- ann_path = Path(ann_path)
143
-
144
- if not ann_path.exists():
145
- raise FileNotFoundError(f"Dataset file {ann_path} not found")
146
-
147
- try:
148
- with open(ann_path, "r") as f:
149
- data = json.load(f)
150
- self.annotation = data["annotation"]
151
- except (json.JSONDecodeError, KeyError):
152
- with open(ann_path, "r") as f:
153
- self.annotation = [json.loads(line) for line in f]
154
-
155
- #### mixup related variables
156
- ### hash table for tasks to sample the tasks faster
157
- self.tasks = defaultdict(list)
158
- for i, ann in enumerate(self.annotation):
159
- if "task" in ann and "text" in ann and ann["text"] != "None" and "path" in ann:
160
- self.tasks[ann["task"]].append(i)
161
-
162
- self.mixup_tasks = {
163
- task: []
164
- for task in self.tasks.keys()
165
- if task.endswith("simple-detection")
166
- or task.endswith("multiple-detection") # Add more tasks after validating prompt mixing.
167
- or task.endswith("sci-detection-random")
168
- or task.endswith("common-detection-random")
169
- }
170
- for k in self.mixup_tasks.keys():
171
- # whichever the base, only mix open-ended tasks.
172
- if "sci-" in k:
173
- self.mixup_tasks[k] = [
174
- task
175
- for task in self.mixup_tasks.keys()
176
- if task.endswith("sci-simple-detection") or task.endswith("sci-multiple-detection")
177
- ]
178
- elif "common-" in k:
179
- self.mixup_tasks[k] = [
180
- task
181
- for task in self.mixup_tasks.keys()
182
- if task.endswith("common-simple-detection") or task.endswith("common-multiple-detection")
183
- ]
184
- else:
185
- self.mixup_tasks[k] = [task for task in self.mixup_tasks.keys() if "common-" in task]
186
-
187
- # print("num annotations", len(self.annotation))
188
- # print("annotation 0", self.annotation[0])
189
- # self.annotation = [a for a in self.annotation if "task" in a and "detection" not in a["task"]] # no detection... :(
190
- self.max_length_seconds = max_length_seconds
191
- self.cropping = cropping
192
- self.use_augmentation = use_augmentation
193
-
194
- ### noise augmentation
195
- self.rng = random.Random(seed)
196
- self.rngnp = np.random.default_rng(seed=seed)
197
- self.noise_dirs = noise_dirs
198
- self.noise_prob = noise_prob
199
- self.noise_files = []
200
- self.low_snr = low_snr
201
- self.high_snr = high_snr
202
- self.mask_audio_prob = mask_audio_prob
203
- if noise_dirs is not None and len(self.noise_dirs) > 0 and self.use_augmentation:
204
- for noise_dir in noise_dirs:
205
- noise_from_dir = glob.glob(os.path.join(noise_dir, "*.wav"))
206
- if len(noise_from_dir) < 3000:
207
- noise_from_dir = noise_from_dir * 3
208
- print("noise files from dir", noise_dir, len(noise_from_dir))
209
- self.noise_files.extend(noise_from_dir)
210
-
211
- ### mixup augmentation
212
- self.mixup_prob = mixup_prob
213
- self.mixup_count = mixup_count
214
- # ### time scale augmentation
215
- self.time_scale = time_scale
216
- self.time_scale_prob = time_scale_prob
217
- # tasks = set([annotation["task"] if "task" in annotation else "empty" for annotation in self.annotation])
218
- print(":::all tasks:::", self.tasks.keys())
219
- print("num examples", len(self.annotation))
220
-
221
- def __len__(self):
222
- return len(self.annotation)
223
-
224
- def collater(self, samples):
225
- return collater(samples)
226
-
227
- def load_audio(self, audio_path, shift_allowed: bool, noise_allowed: bool):
228
- audio, sr = sf.read(audio_path)
229
- # assert sr == 16000
230
- if sr != 16000:
231
- print("other sr!", sr, audio_path)
232
- if len(audio.shape) == 2: # stereo to mono
233
- audio = audio.mean(axis=1)
234
-
235
- ### time scale augmentation
236
- if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
237
- # print(f"{index} scaling audio")
238
- # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] )
239
- audio = time_scale(torch.tensor(audio), scale=self.time_scale, rngnp=self.rngnp).numpy()
240
- # write_example_to_file(os.path.basename(ann["path"]), audio[: sr * self.max_length_seconds] , suffix='_ts')
241
-
242
- # Randomly crop a max_length_seconds window if audio is longer than 10 seconds
243
- if len(audio) > sr * self.max_length_seconds and self.cropping == "random":
244
- max_start = len(audio) - sr * self.max_length_seconds
245
- start = random.randint(0, max_start)
246
- audio = audio[start : start + sr * self.max_length_seconds]
247
- else: # no random cropping
248
- audio = audio[: sr * self.max_length_seconds] # Truncate audio to at most max_length_seconds
249
-
250
- ### noise augmentation
251
- audio = torch.tensor(audio)
252
- ### noise augmentation
253
- if (
254
- self.use_augmentation
255
- and self.rng.random() < self.noise_prob
256
- and len(self.noise_files) > 0
257
- and noise_allowed
258
- ):
259
- # write_example_to_file(os.path.basename(ann["path"]), audio)
260
- # print(f"{index} adding noise")
261
- noise_file = self.rng.choice(self.noise_files)
262
- if not os.path.exists(noise_file):
263
- print(f"Warning: noise file {noise_file} does not exist")
264
- else:
265
- noise_audio, noise_sr = sf.read(noise_file)
266
- assert noise_sr == 16000
267
- if len(noise_audio.shape) == 2:
268
- noise_audio = noise_audio.mean(axis=1)
269
-
270
- noise_audio = torch.tensor(noise_audio)
271
-
272
- ### repeat or trim to the audio size
273
- if len(audio) > len(noise_audio):
274
- if len(noise_audio) == 0:
275
- print(
276
- "----- Warning: Noise audio length is zero. ---------- ",
277
- noise_file,
278
- )
279
- # Option 1: Skip noise augmentation by setting noise_audio to zero
280
- noise_audio = torch.zeros_like(audio)
281
- else:
282
- nrepeats = int(np.maximum(2, np.ceil(len(audio) / len(noise_audio))))
283
- noise_audio = noise_audio.repeat(nrepeats)
284
- ### Randomly crop the noise file if it is too long
285
- if len(noise_audio) > len(audio):
286
- max_start = len(noise_audio) - len(audio)
287
- start = random.randint(0, max_start)
288
- noise_audio = noise_audio[start : start + len(audio)]
289
-
290
- ### remix with specified snr
291
- snr = self.rngnp.uniform(self.low_snr, self.high_snr)
292
- snr = torch.tensor([snr])
293
- noise_audio = snr_scale(audio, noise_audio, snr)
294
- audio = audio + noise_audio
295
-
296
- # write_example_to_file(os.path.basename(audio_path), audio, suffix='_noise')
297
- if len(audio) > self.max_length_seconds * sr:
298
- print("long audio", len(audio), len(noise_audio))
299
- audio = audio[: self.max_length_seconds * sr]
300
-
301
- # pad all audios to max_len_seconds in _getitem_ to ensure no padding inconsistencies.
302
- if len(audio) < sr * self.max_length_seconds:
303
- pad_size = sr * self.max_length_seconds - len(audio)
304
- audio = torch.nn.functional.pad(audio, (0, pad_size))
305
-
306
- audio = torch.clamp(audio, -1.0, 1.0)
307
-
308
- return audio
309
-
310
- def _mix_labels(self, text, text_to_mix):
311
- """
312
- Given two comma-separated label strings (e.g., "gorilla, zebra"),
313
- combine them without introducing duplicates. If either is "None",
314
- return the other as-is (unless both are "None").
315
- """
316
- # If `text_to_mix` is explicitly "None", just return `text`.
317
- if text_to_mix == "None":
318
- return text
319
-
320
- # If `text` is explicitly "None", just return `text_to_mix`.
321
- if text == "None":
322
- return text_to_mix
323
-
324
- # Split both strings by comma, stripping whitespace
325
- text_list = [item.strip() for item in text.split(",") if item.strip()]
326
- text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]
327
-
328
- # Deduplicate: add only new items from text_to_mix_list
329
- combined_set = set(text_list)
330
- for item in text_to_mix_list:
331
- if item not in combined_set:
332
- text_list.append(item)
333
- combined_set.add(item)
334
-
335
- # If there's nothing left after deduplication, return "None".
336
- if not text_list:
337
- return "None"
338
-
339
- # Rejoin them into a comma-separated string
340
- return ", ".join(text_list)
341
-
342
- def _mix_prompts(self, text, text_to_mix, prompt):
343
- """
344
- If the prompt is in the form:
345
- "Which of these, if any, are present in the audio recording? option1, option2, ..."
346
-
347
- 1. Parse out the question (before '?') and the list of prompt choices (after '?').
348
- 2. Convert both `text` and `text_to_mix` into lists, checking for items not in the prompt.
349
- 3. Append any missing answers to the prompt choices.
350
- 4. Shuffle the choices.
351
- 5. Reassemble and return the new prompt.
352
-
353
- If the prompt does not follow the expected structure, it is returned unmodified.
354
- """
355
- # Split into two parts: question + choices
356
- splitted = prompt.split("?")
357
- if len(splitted) != 2:
358
- # If we don't have exactly one question mark segment, just return the original prompt
359
- return prompt
360
-
361
- question = splitted[0].strip()
362
- potential_choices_str = splitted[1].strip()
363
-
364
- # Split the prompt choices
365
- if not potential_choices_str:
366
- prompt_choices = []
367
- else:
368
- prompt_choices = [c.strip() for c in potential_choices_str.split(",") if c.strip()]
369
-
370
- # Parse `text`
371
- text_list = [item.strip() for item in text.split(",") if item.strip()]
372
-
373
- # Parse `text_to_mix`
374
- text_to_mix_list = [item.strip() for item in text_to_mix.split(",") if item.strip()]
375
-
376
- # Add any new items from text_list to the prompt
377
- for item in text_list:
378
- if item not in prompt_choices:
379
- prompt_choices.append(item)
380
-
381
- # Add any new items from text_to_mix_list to the prompt
382
- for item in text_to_mix_list:
383
- if item not in prompt_choices:
384
- prompt_choices.append(item)
385
-
386
- # Shuffle consistently with self.rng
387
- self.rng.shuffle(prompt_choices)
388
-
389
- # Reassemble
390
- new_prompt = question + "? " + ", ".join(prompt_choices)
391
- return new_prompt
392
-
393
- def _apply_mixup(self, prompt, audio, text, task, filename=None):
394
- # mixup_applied = False
395
- if (
396
- self.use_augmentation and self.rng.random() < self.mixup_prob and task in self.mixup_tasks
397
- # and text != "None" # Allow complex 'None' examples.
398
- ):
399
- # write_example_to_file(os.path.basename(ann["path"]), audio)
400
- # print(f"{index} mixing up")
401
- mixup_indices = []
402
- for pair_task in self.mixup_tasks[task]:
403
- mixup_indices.extend(self.tasks[pair_task])
404
- # mixup_indices = mixup_indices.remove(index)
405
-
406
- if len(mixup_indices) == 0:
407
- print("No mixup partner found")
408
- else:
409
- ### choose n_mixup random partners
410
- n_mixup = self.rng.randint(1, self.mixup_count)
411
- mixup_indices = self.rng.sample(mixup_indices, n_mixup)
412
- # print(f"Mixing up with indices {mixup_indices}")
413
- for mixup_index in mixup_indices:
414
- mixup_ann = self.annotation[mixup_index]
415
- mixup_audio, _ = sf.read(mixup_ann["path"])
416
- if len(mixup_audio.shape) == 2:
417
- mixup_audio = mixup_audio.mean(axis=1)
418
- mixup_audio = mixup_audio[: len(audio)]
419
- if len(mixup_audio) < len(audio):
420
- pad_size = len(audio) - len(mixup_audio)
421
- mixup_audio = np.pad(mixup_audio, (0, pad_size), mode="constant")
422
- mixup_audio = torch.from_numpy(mixup_audio).float()
423
- lam = np.clip(self.rngnp.beta(1.0, 1.0), 0.1, 0.8)
424
-
425
- # Mix the raw_wav
426
- audio = lam * audio + (1 - lam) * mixup_audio
427
-
428
- ### Mix the prompts if the labels are given in prompts
429
- if text in prompt:
430
- prompt = self._mix_prompts(text, mixup_ann["text"], prompt)
431
-
432
- ### Mix the labels
433
- text = self._mix_labels(text, mixup_ann["text"])
434
-
435
- # mixup_applied = True
436
-
437
- # DEBUG: If mixup was actually applied, save the final audio
438
- # if mixup_applied and filename is not None:
439
- # # Just add a suffix to the original filename to indicate mixup
440
- # base_filename = os.path.basename(filename)
441
- # write_example_to_file(
442
- # base_filename=base_filename,
443
- # audio=audio,
444
- # sr=16000,
445
- # suffix="_mixup",
446
- # save_dir="mixup_outputs"
447
- # )
448
- # print(f"mixup for {filename}::: prompt {prompt} label {text}")
449
-
450
- return prompt, audio, text
451
-
452
- def _load_noise(self, shift_allowed: bool):
453
- noise_file = self.rng.choice(self.noise_files)
454
- noise_audio, noise_sr = sf.read(noise_file)
455
- assert noise_sr == 16000, f"Expected noise sample rate 16000, got {noise_sr}"
456
- if len(noise_audio.shape) == 2:
457
- noise_audio = noise_audio.mean(axis=1)
458
-
459
- # Time scale augmentation if applicable
460
- if self.use_augmentation and self.rng.random() < self.time_scale_prob and self.time_scale > 0 and shift_allowed:
461
- noise_audio = time_scale(torch.tensor(noise_audio), scale=self.time_scale, rngnp=self.rngnp).numpy()
462
-
463
- # Randomly crop or pad to match max_length_seconds
464
- if len(noise_audio) > self.max_length_seconds * 16000 and self.cropping == "random":
465
- max_start = len(noise_audio) - self.max_length_seconds * 16000
466
- start = random.randint(0, max_start)
467
- noise_audio = noise_audio[start : start + self.max_length_seconds * 16000]
468
- else:
469
- noise_audio = noise_audio[: self.max_length_seconds * 16000]
470
-
471
- # Pad if needed
472
- if len(noise_audio) < self.max_length_seconds * 16000:
473
- pad_size = self.max_length_seconds * 16000 - len(noise_audio)
474
- noise_audio = np.pad(noise_audio, (0, pad_size), mode="constant")
475
-
476
- noise_audio = torch.tensor(noise_audio).float()
477
- noise_audio = torch.clamp(noise_audio, -1.0, 1.0)
478
- return noise_audio
479
-
480
- def __getitem__(self, index):
481
- ann = self.annotation[index]
482
- # print("loading audio::", ann)
483
- shift_allowed = "pitch" not in ann.get("task", "")
484
- noise_allowed = (
485
- "/A/" not in ann.get("path", "")
486
- and "-qa" not in ann.get("task", "")
487
- and "icl" not in ann.get("task", "")
488
- and "caption" not in ann.get("task", "")
489
- and "animal-instructions" not in ann.get("task", "")
490
- )
491
-
492
- task = ann.get("task", "asr")
493
- text = ann["text"]
494
- prompt = ann["prompt"]
495
-
496
- replace_with_noise = (
497
- self.use_augmentation
498
- and task.endswith("detection")
499
- and self.rng.random() < self.mask_audio_prob
500
- and len(self.noise_files) > 0
501
- )
502
-
503
- if replace_with_noise:
504
- # Replace audio with noise
505
- audio = self._load_noise(shift_allowed)
506
- audios = [audio]
507
- text = "None"
508
-
509
- else:
510
- if "path" in ann and ann["path"] is not None:
511
- audio = self.load_audio(ann["path"], shift_allowed, noise_allowed)
512
- audios = [audio]
513
- else:
514
- audios = [self.load_audio(p, shift_allowed, noise_allowed) for p in ann["files"]]
515
-
516
- if len(audios) == 1:
517
- prompt, mixed_audio, text = self._apply_mixup(prompt, audio, text, task, filename=ann["path"])
518
- audios = [mixed_audio]
519
-
520
- return {
521
- "raw_wav": audios,
522
- "text": text,
523
- "task": task,
524
- "id": ann.get("path") or ";".join(ann["files"]),
525
- "prompt": prompt,
526
- "index": index, # track which element for eval output
527
- "ann": ann, # Include annotation for mixup
528
- }
529
-
530
-
531
- if __name__ == "__main__":
532
- dataset = NatureLMDataset(
533
- ann_path="/home/ubuntu/foundation-model-storage/foundation-model-data/data/compiled-datasets/v1/s2_eval_valid.jsonl",
534
- noise_dirs=["resource/audio_demo"],
535
- max_length_seconds=10,
536
- use_augmentation=True,
537
- mixup_prob=1.0, # For demonstration, force mixup if possible
538
- mixup_count=2, # Up to 2 mixup partners
539
- mask_audio_prob=0.2,
540
- seed=42,
541
- noise_prob=0.5,
542
- )
543
-
544
- # Process just a few to see the saved mixups
545
- for i in range(300):
546
- sample = dataset[i]
547
- # print("Final text:", sample["text"])
548
- # print("Final prompt:", sample["prompt"])
549
- # print("-" * 40)
550
- print("Done! Look in 'debug_outputs' folder for saved mixup files.")
NatureLM/dist_utils.py DELETED
@@ -1,109 +0,0 @@
1
- """
2
- Adapted from salesforce@LAVIS. Below is the original copyright:
3
- Copyright (c) 2022, salesforce.com, inc.
4
- All rights reserved.
5
- SPDX-License-Identifier: BSD-3-Clause
6
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
- """
8
-
9
- import datetime
10
- import functools
11
- import os
12
-
13
- import torch
14
- import torch.distributed as dist
15
-
16
-
17
- def setup_for_distributed(is_master):
18
- """
19
- This function disables printing when not in master process
20
- """
21
- import builtins as __builtin__
22
-
23
- builtin_print = __builtin__.print
24
-
25
- def print(*args, **kwargs):
26
- force = kwargs.pop("force", False)
27
- if is_master or force:
28
- builtin_print(*args, **kwargs)
29
-
30
- __builtin__.print = print
31
-
32
-
33
- def is_dist_avail_and_initialized():
34
- if not dist.is_available():
35
- return False
36
- if not dist.is_initialized():
37
- return False
38
- return True
39
-
40
-
41
- def get_world_size():
42
- if not is_dist_avail_and_initialized():
43
- return 1
44
- return dist.get_world_size()
45
-
46
-
47
- def get_rank():
48
- if not is_dist_avail_and_initialized():
49
- return 0
50
- return dist.get_rank()
51
-
52
-
53
- def is_main_process():
54
- return get_rank() == 0
55
-
56
-
57
- def init_distributed_mode(args):
58
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
- args.rank = int(os.environ["RANK"])
60
- args.world_size = int(os.environ["WORLD_SIZE"])
61
- args.gpu = int(os.environ["LOCAL_RANK"])
62
- elif "SLURM_PROCID" in os.environ:
63
- args.rank = int(os.environ["SLURM_PROCID"])
64
- args.gpu = args.rank % torch.cuda.device_count()
65
- else:
66
- print("Not using distributed mode")
67
- args.use_distributed = False
68
- return
69
-
70
- args.use_distributed = True
71
-
72
- torch.cuda.set_device(args.gpu)
73
- print(
74
- "| distributed init (rank {}, world {}): {}".format(args.rank, args.world_size, args.dist_url),
75
- flush=True,
76
- )
77
- torch.distributed.init_process_group(
78
- backend=args.dist_backend,
79
- init_method=args.dist_url,
80
- world_size=args.world_size,
81
- rank=args.rank,
82
- timeout=datetime.timedelta(days=365), # allow auto-downloading and de-compressing
83
- )
84
- torch.distributed.barrier()
85
- setup_for_distributed(args.rank == 0)
86
-
87
-
88
- def get_dist_info():
89
- if torch.__version__ < "1.0":
90
- initialized = dist._initialized
91
- else:
92
- initialized = dist.is_initialized()
93
- if initialized:
94
- rank = dist.get_rank()
95
- world_size = dist.get_world_size()
96
- else: # non-distributed training
97
- rank = 0
98
- world_size = 1
99
- return rank, world_size
100
-
101
-
102
- def main_process(func):
103
- @functools.wraps(func)
104
- def wrapper(*args, **kwargs):
105
- rank, _ = get_dist_info()
106
- if rank == 0:
107
- return func(*args, **kwargs)
108
-
109
- return wrapper
 
 
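The deleted init_distributed_mode did nothing exotic: it read the environment variables that launchers such as torchrun export (RANK, WORLD_SIZE, LOCAL_RANK), fell back to SLURM_PROCID, and otherwise disabled distributed mode. A sketch of just that detection step (illustrative only; no process group is initialised here, and the helper name is made up):

import os

def detect_launcher() -> dict | None:
    # Same precedence as the deleted init_distributed_mode.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        return {
            "rank": int(os.environ["RANK"]),
            "world_size": int(os.environ["WORLD_SIZE"]),
            "gpu": int(os.environ["LOCAL_RANK"]),
        }
    if "SLURM_PROCID" in os.environ:
        return {"rank": int(os.environ["SLURM_PROCID"])}
    return None  # not using distributed mode

print(detect_launcher())
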
NatureLM/infer.py CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
5
 
6
  import numpy as np
7
  import pandas as pd
8
- import soundfile as sf
9
  import torch
10
 
11
  from NatureLM.config import Config
@@ -16,10 +16,15 @@ from NatureLM.utils import move_to_device
16
  _MAX_LENGTH_SECONDS = 10
17
  _MIN_CHUNK_LENGTH_SECONDS = 0.5
18
  _SAMPLE_RATE = 16000 # Assuming the model uses a sample rate of 16kHz
19
- _AUDIO_FILE_EXTENSIONS = [".wav", ".mp3", ".flac", ".ogg"] # Add other audio file formats as needed
20
- _DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
21
- __this_dir = Path(__file__).parent.parent
22
- _DEFAULT_CONFIG_PATH = __this_dir / "configs" / "inference.yml"
 
 
 
 
 
23
 
24
 
25
  def load_model_and_config(
@@ -32,7 +37,9 @@ def load_model_and_config(
32
  model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
33
  model = model.to(device).eval()
34
  model.llama_tokenizer.pad_token_id = model.llama_tokenizer.eos_token_id
35
- model.llama_model.generation_config.pad_token_id = model.llama_tokenizer.pad_token_id
 
 
36
 
37
  cfg = Config.from_sources(cfg_path)
38
  return model, cfg
@@ -53,7 +60,7 @@ def sliding_window_inference(
53
  hop_length_seconds: float = 10.0,
54
  input_sr: int = _SAMPLE_RATE,
55
  device: str = _DEVICE,
56
- ) -> str:
57
  """Run inference on a long audio file using sliding window approach.
58
 
59
  Args:
@@ -73,7 +80,7 @@ def sliding_window_inference(
73
  ValueError: If the audio file is too short or if the audio file path is invalid.
74
  """
75
  if isinstance(audio, str) or isinstance(audio, Path):
76
- audio_array, input_sr = sf.read(str(audio))
77
  elif isinstance(audio, np.ndarray):
78
  audio_array = audio
79
  print(f"Using provided sample rate: {input_sr}")
@@ -86,13 +93,16 @@ def sliding_window_inference(
86
 
87
  # Do initial check that the audio is long enough
88
  if audio_array.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
89
- raise ValueError(f"Audio is too short. Minimum length is {_MIN_CHUNK_LENGTH_SECONDS} seconds.")
 
 
90
 
91
  start = 0
92
  stride = int(hop_length_seconds * input_sr)
93
  window_length = int(window_length_seconds * input_sr)
 
94
 
95
- output = ""
96
  while True:
97
  chunk = audio_array[start : start + window_length]
98
  if chunk.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
@@ -113,8 +123,16 @@ def sliding_window_inference(
113
  prediction: str = model.generate(input_to_model, cfg.generate, prompt_list)[0]
114
 
115
  # Post-process the prediction
116
- prediction = output_template(prediction, start / input_sr, (start + window_length) / input_sr)
117
- output += prediction
 
 
 
 
 
 
 
 
118
 
119
  # Move the window
120
  start += stride
@@ -128,7 +146,9 @@ def sliding_window_inference(
128
  class Pipeline:
129
  """Pipeline for running NatureLM-audio inference on a list of audio files or audio arrays"""
130
 
131
- def __init__(self, model: NatureLM = None, cfg_path: str | Path = _DEFAULT_CONFIG_PATH):
 
 
132
  self.cfg_path = cfg_path
133
 
134
  # Load model and config
@@ -139,7 +159,9 @@ class Pipeline:
139
  # Download model from hub
140
  self.model, self.cfg = load_model_and_config(cfg_path)
141
 
142
- self.processor = NatureLMAudioProcessor(sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS)
 
 
143
 
144
  def __call__(
145
  self,
@@ -149,6 +171,7 @@ class Pipeline:
149
  hop_length_seconds: float = 10.0,
150
  input_sample_rate: int = _SAMPLE_RATE,
151
  verbose: bool = False,
 
152
  ) -> list[str]:
153
  """Run inference on a list of audio file paths or a single audio file with a
154
  single query or a list of queries. If multiple queries are provided,
@@ -165,18 +188,11 @@ class Pipeline:
165
  Defaults to False.
166
 
167
  Returns:
168
- str | list[str]: The output of the model..
 
169
 
170
  Raises:
171
  ValueError: If the number of audio files and queries do not match.
172
-
173
- Example:
174
- >>> pipeline = Pipeline()
175
- >>> audios = ["assets/nri-GreenTreeFrogEvergladesNP.mp3"]
176
- >>> queries = ["Which species is this? Provide the common name."]
177
- >>> results = pipeline(audios, queries)
178
- >>> print(results)
179
- ['#0.00s - 10.00s#: Green Treefrog\n']
180
  """
181
  if isinstance(audios, str) or isinstance(audios, Path):
182
  audios = [audios]
@@ -189,7 +205,10 @@ class Pipeline:
189
 
190
  # Run inference
191
  results = []
192
- for audio, query in zip(audios, queries):
 
 
 
193
  output = sliding_window_inference(
194
  audio,
195
  query,
@@ -209,21 +228,38 @@ class Pipeline:
209
  def parse_args() -> argparse.Namespace:
210
  parser = argparse.ArgumentParser("Run NatureLM-audio inference")
211
  parser.add_argument(
212
- "-a", "--audio", type=str, required=True, help="Path to an audio file or a directory containing audio files"
 
 
 
 
 
 
 
213
  )
214
- parser.add_argument("-q", "--query", type=str, required=True, help="Query for the model")
215
  parser.add_argument(
216
  "--cfg-path",
217
  type=str,
218
  default="configs/inference.yml",
219
  help="Path to the configuration file for the model",
220
  )
221
- parser.add_argument("--output_path", type=str, default="inference_output.jsonl", help="Output path for the results")
222
  parser.add_argument(
223
- "--window_length_seconds", type=float, default=10.0, help="Length of the sliding window in seconds"
 
 
 
 
 
 
 
 
 
224
  )
225
  parser.add_argument(
226
- "--hop_length_seconds", type=float, default=10.0, help="Hop length for the sliding window in seconds"
 
 
 
227
  )
228
  args = parser.parse_args()
229
 
@@ -261,7 +297,9 @@ def main(
261
  audio_path = Path(audio_path)
262
  if audio_path.is_dir():
263
  audio_paths = []
264
- print(f"Searching for audio files in {str(audio_path)} with extensions {', '.join(_AUDIO_FILE_EXTENSIONS)}")
 
 
265
  for ext in _AUDIO_FILE_EXTENSIONS:
266
  audio_paths.extend(list(audio_path.rglob(f"*{ext}")))
267
 
@@ -278,18 +316,30 @@ def main(
278
  if not query:
279
  raise ValueError("Query cannot be empty")
280
  if not audio_paths:
281
- raise ValueError("No audio files found. Please check the path or file extensions.")
 
 
282
 
283
  # Load model and config
284
  model, cfg = load_model_and_config(cfg_path)
285
 
286
  # Load audio processor
287
- processor = NatureLMAudioProcessor(sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS)
 
 
288
 
289
  # Run inference
290
  results = {"audio_path": [], "output": []}
291
  for path in audio_paths:
292
- output = sliding_window_inference(path, query, processor, model, cfg, window_length_seconds, hop_length_seconds)
 
 
 
 
 
 
 
 
293
  results["audio_path"].append(str(path))
294
  results["output"].append(output)
295
  print(f"Processed {path}, model output:\n=======\n{output}\n=======\n")
 
5
 
6
  import numpy as np
7
  import pandas as pd
8
+ import librosa
9
  import torch
10
 
11
  from NatureLM.config import Config
 
16
  _MAX_LENGTH_SECONDS = 10
17
  _MIN_CHUNK_LENGTH_SECONDS = 0.5
18
  _SAMPLE_RATE = 16000 # Assuming the model uses a sample rate of 16kHz
19
+ _AUDIO_FILE_EXTENSIONS = [
20
+ ".wav",
21
+ ".mp3",
22
+ ".flac",
23
+ ".ogg",
24
+ ] # Add other audio file formats as needed
25
+ _DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+ __root_dir = Path(__file__).parent.parent
27
+ _DEFAULT_CONFIG_PATH = __root_dir / "configs" / "inference.yml"
28
 
29
 
30
  def load_model_and_config(
 
37
  model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
38
  model = model.to(device).eval()
39
  model.llama_tokenizer.pad_token_id = model.llama_tokenizer.eos_token_id
40
+ model.llama_model.generation_config.pad_token_id = (
41
+ model.llama_tokenizer.pad_token_id
42
+ )
43
 
44
  cfg = Config.from_sources(cfg_path)
45
  return model, cfg
 
60
  hop_length_seconds: float = 10.0,
61
  input_sr: int = _SAMPLE_RATE,
62
  device: str = _DEVICE,
63
+ ) -> list[dict[str, any]]:
64
  """Run inference on a long audio file using sliding window approach.
65
 
66
  Args:
 
80
  ValueError: If the audio file is too short or if the audio file path is invalid.
81
  """
82
  if isinstance(audio, str) or isinstance(audio, Path):
83
+ audio_array, input_sr = librosa.load(str(audio), sr=None, mono=False)
84
  elif isinstance(audio, np.ndarray):
85
  audio_array = audio
86
  print(f"Using provided sample rate: {input_sr}")
 
93
 
94
  # Do initial check that the audio is long enough
95
  if audio_array.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
96
+ raise ValueError(
97
+ f"Audio is too short. Minimum length is {_MIN_CHUNK_LENGTH_SECONDS} seconds."
98
+ )
99
 
100
  start = 0
101
  stride = int(hop_length_seconds * input_sr)
102
  window_length = int(window_length_seconds * input_sr)
103
+ window_id = 0
104
 
105
+ output = [] # Initialize output list
106
  while True:
107
  chunk = audio_array[start : start + window_length]
108
  if chunk.shape[-1] < int(_MIN_CHUNK_LENGTH_SECONDS * input_sr):
 
123
  prediction: str = model.generate(input_to_model, cfg.generate, prompt_list)[0]
124
 
125
  # Post-process the prediction
126
+ # prediction = output_template(prediction, start / input_sr, (start + window_length) / input_sr)
127
+ # output += prediction
128
+ output.append(
129
+ {
130
+ "start_time": start / input_sr,
131
+ "end_time": (start + window_length) / input_sr,
132
+ "prediction": prediction,
133
+ "window_number": window_id,
134
+ }
135
+ )
136
 
137
  # Move the window
138
  start += stride
 
146
  class Pipeline:
147
  """Pipeline for running NatureLM-audio inference on a list of audio files or audio arrays"""
148
 
149
+ def __init__(
150
+ self, model: NatureLM = None, cfg_path: str | Path = _DEFAULT_CONFIG_PATH
151
+ ):
152
  self.cfg_path = cfg_path
153
 
154
  # Load model and config
 
159
  # Download model from hub
160
  self.model, self.cfg = load_model_and_config(cfg_path)
161
 
162
+ self.processor = NatureLMAudioProcessor(
163
+ sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS
164
+ )
165
 
166
  def __call__(
167
  self,
 
171
  hop_length_seconds: float = 10.0,
172
  input_sample_rate: int = _SAMPLE_RATE,
173
  verbose: bool = False,
174
+ progress_bar=None,
175
  ) -> list[str]:
176
  """Run inference on a list of audio file paths or a single audio file with a
177
  single query or a list of queries. If multiple queries are provided,
 
188
  Defaults to False.
189
 
190
  Returns:
191
+ list[list[dict]]: List of model outputs for each audio file. Each output is a list of dictionaries
192
+ containing the start time, end time, and prediction for each chunk of audio.
193
 
194
  Raises:
195
  ValueError: If the number of audio files and queries do not match.
 
 
 
 
 
 
 
 
196
  """
197
  if isinstance(audios, str) or isinstance(audios, Path):
198
  audios = [audios]
 
205
 
206
  # Run inference
207
  results = []
208
+ progress_bar(0, desc="Starting")
209
+ for audio, query in progress_bar.tqdm(
210
+ zip(audios, queries), desc="Generating responses", total=len(audios)
211
+ ):
212
  output = sliding_window_inference(
213
  audio,
214
  query,
 
228
  def parse_args() -> argparse.Namespace:
229
  parser = argparse.ArgumentParser("Run NatureLM-audio inference")
230
  parser.add_argument(
231
+ "-a",
232
+ "--audio",
233
+ type=str,
234
+ required=True,
235
+ help="Path to an audio file or a directory containing audio files",
236
+ )
237
+ parser.add_argument(
238
+ "-q", "--query", type=str, required=True, help="Query for the model"
239
  )
 
240
  parser.add_argument(
241
  "--cfg-path",
242
  type=str,
243
  default="configs/inference.yml",
244
  help="Path to the configuration file for the model",
245
  )
 
246
  parser.add_argument(
247
+ "--output_path",
248
+ type=str,
249
+ default="inference_output.jsonl",
250
+ help="Output path for the results",
251
+ )
252
+ parser.add_argument(
253
+ "--window_length_seconds",
254
+ type=float,
255
+ default=10.0,
256
+ help="Length of the sliding window in seconds",
257
  )
258
  parser.add_argument(
259
+ "--hop_length_seconds",
260
+ type=float,
261
+ default=10.0,
262
+ help="Hop length for the sliding window in seconds",
263
  )
264
  args = parser.parse_args()
265
 
 
297
  audio_path = Path(audio_path)
298
  if audio_path.is_dir():
299
  audio_paths = []
300
+ print(
301
+ f"Searching for audio files in {str(audio_path)} with extensions {', '.join(_AUDIO_FILE_EXTENSIONS)}"
302
+ )
303
  for ext in _AUDIO_FILE_EXTENSIONS:
304
  audio_paths.extend(list(audio_path.rglob(f"*{ext}")))
305
 
 
316
  if not query:
317
  raise ValueError("Query cannot be empty")
318
  if not audio_paths:
319
+ raise ValueError(
320
+ "No audio files found. Please check the path or file extensions."
321
+ )
322
 
323
  # Load model and config
324
  model, cfg = load_model_and_config(cfg_path)
325
 
326
  # Load audio processor
327
+ processor = NatureLMAudioProcessor(
328
+ sample_rate=_SAMPLE_RATE, max_length_seconds=_MAX_LENGTH_SECONDS
329
+ )
330
 
331
  # Run inference
332
  results = {"audio_path": [], "output": []}
333
  for path in audio_paths:
334
+ output = sliding_window_inference(
335
+ path,
336
+ query,
337
+ processor,
338
+ model,
339
+ cfg,
340
+ window_length_seconds,
341
+ hop_length_seconds,
342
+ )
343
  results["audio_path"].append(str(path))
344
  results["output"].append(output)
345
  print(f"Processed {path}, model output:\n=======\n{output}\n=======\n")
NatureLM/logger.py DELETED
@@ -1,190 +0,0 @@
1
- import datetime
2
- import logging
3
- import time
4
- from collections import defaultdict, deque
5
-
6
- import torch
7
- import torch.distributed as dist
8
- import wandb
9
-
10
- from NatureLM.dist_utils import is_dist_avail_and_initialized, is_main_process
11
-
12
-
13
- class SmoothedValue(object):
14
- """Track a series of values and provide access to smoothed values over a
15
- window or the global series average.
16
- """
17
-
18
- def __init__(self, window_size=20, fmt=None):
19
- if fmt is None:
20
- fmt = "{median:.4f} ({global_avg:.4f})"
21
- self.deque = deque(maxlen=window_size)
22
- self.total = 0.0
23
- self.count = 0
24
- self.fmt = fmt
25
-
26
- def update(self, value, n=1):
27
- self.deque.append(value)
28
- self.count += n
29
- self.total += value * n
30
-
31
- def synchronize_between_processes(self):
32
- """
33
- Warning: does not synchronize the deque!
34
- """
35
- if not is_dist_avail_and_initialized():
36
- return
37
- t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
38
- dist.barrier()
39
- dist.all_reduce(t)
40
- t = t.tolist()
41
- self.count = int(t[0])
42
- self.total = t[1]
43
-
44
- @property
45
- def median(self):
46
- d = torch.tensor(list(self.deque))
47
- return d.median().item()
48
-
49
- @property
50
- def avg(self):
51
- d = torch.tensor(list(self.deque), dtype=torch.float32)
52
- return d.mean().item()
53
-
54
- @property
55
- def global_avg(self):
56
- return self.total / self.count
57
-
58
- @property
59
- def max(self):
60
- return max(self.deque)
61
-
62
- @property
63
- def value(self):
64
- return self.deque[-1]
65
-
66
- def __str__(self):
67
- return self.fmt.format(
68
- median=self.median,
69
- avg=self.avg,
70
- global_avg=self.global_avg,
71
- max=self.max,
72
- value=self.value,
73
- )
74
-
75
-
76
- class MetricLogger(object):
77
- def __init__(self, delimiter="\t"):
78
- self.meters = defaultdict(SmoothedValue)
79
- self.delimiter = delimiter
80
-
81
- def update(self, **kwargs):
82
- for k, v in kwargs.items():
83
- if isinstance(v, torch.Tensor):
84
- v = v.item()
85
- assert isinstance(v, (float, int))
86
- self.meters[k].update(v)
87
-
88
- def __getattr__(self, attr):
89
- if attr in self.meters:
90
- return self.meters[attr]
91
- if attr in self.__dict__:
92
- return self.__dict__[attr]
93
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
94
-
95
- def __str__(self):
96
- loss_str = []
97
- for name, meter in self.meters.items():
98
- loss_str.append("{}: {}".format(name, str(meter)))
99
- return self.delimiter.join(loss_str)
100
-
101
- def global_avg(self):
102
- loss_str = []
103
- for name, meter in self.meters.items():
104
- loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
105
- return self.delimiter.join(loss_str)
106
-
107
- def synchronize_between_processes(self):
108
- for meter in self.meters.values():
109
- meter.synchronize_between_processes()
110
-
111
- def add_meter(self, name, meter):
112
- self.meters[name] = meter
113
-
114
- def log_every(self, iterable, print_freq, header=None, logger=None, start_step=None):
115
- i = 0
116
- if not header:
117
- header = ""
118
- start_time = time.time()
119
- end = time.time()
120
- iter_time = SmoothedValue(fmt="{avg:.4f}")
121
- data_time = SmoothedValue(fmt="{avg:.4f}")
122
- space_fmt = ":" + str(len(str(len(iterable)))) + "d"
123
- log_msg = [
124
- header,
125
- "[{0" + space_fmt + "}/{1}]",
126
- "eta: {eta}",
127
- "{meters}",
128
- "time: {time}",
129
- "data: {data}",
130
- ]
131
- if torch.cuda.is_available():
132
- log_msg.append("max mem: {memory:.0f}")
133
- log_msg = self.delimiter.join(log_msg)
134
- MB = 1024.0 * 1024.0
135
- for obj in iterable:
136
- data_time.update(time.time() - end)
137
- yield obj
138
- iter_time.update(time.time() - end)
139
- if i % print_freq == 0 or i == len(iterable) - 1:
140
- if is_main_process():
141
- if logger is not None:
142
- assert start_step is not None, "start_step is needed to compute global_step!"
143
- for name, meter in self.meters.items():
144
- logger.add_scalar("{}".format(name), float(str(meter)), global_step=start_step + i)
145
- # Log to wandb
146
- wandb.log({name: float(str(meter)) for name, meter in self.meters.items()}, step=start_step + i)
147
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
148
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149
- if torch.cuda.is_available():
150
- print(
151
- log_msg.format(
152
- i,
153
- len(iterable),
154
- eta=eta_string,
155
- meters=str(self),
156
- time=str(iter_time),
157
- data=str(data_time),
158
- memory=torch.cuda.max_memory_allocated() / MB,
159
- )
160
- )
161
- else:
162
- print(
163
- log_msg.format(
164
- i,
165
- len(iterable),
166
- eta=eta_string,
167
- meters=str(self),
168
- time=str(iter_time),
169
- data=str(data_time),
170
- )
171
- )
172
- i += 1
173
- end = time.time()
174
- total_time = time.time() - start_time
175
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
176
- print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
177
-
178
-
179
- class AttrDict(dict):
180
- def __init__(self, *args, **kwargs):
181
- super(AttrDict, self).__init__(*args, **kwargs)
182
- self.__dict__ = self
183
-
184
-
185
- def setup_logger():
186
- logging.basicConfig(
187
- level=logging.INFO if is_main_process() else logging.WARN,
188
- format="%(asctime)s [%(levelname)s] %(message)s",
189
- handlers=[logging.StreamHandler()],
190
- )
NatureLM/models/NatureLM.py CHANGED
@@ -645,7 +645,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
645
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
646
 
647
  with torch.autocast(self.device.type, dtype=torch.bfloat16):
648
- outputs = self.llama_model.generate( # TODO: Wrap the llama_model with outlines https://outlines-dev.github.io/outlines/reference/models/transformers/
649
  inputs_embeds=embeds.bfloat16(),
650
  max_new_tokens=generate_cfg.max_new_tokens,
651
  stopping_criteria=stopping_criteria,
 
645
  stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
646
 
647
  with torch.autocast(self.device.type, dtype=torch.bfloat16):
648
+ outputs = self.llama_model.generate(
649
  inputs_embeds=embeds.bfloat16(),
650
  max_new_tokens=generate_cfg.max_new_tokens,
651
  stopping_criteria=stopping_criteria,
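
The generate call above relies on a stop-word criterion (StoppingCriteriaSub, defined elsewhere in this repository). A generic sketch of such a criterion under the standard transformers StoppingCriteria interface; this is an illustration, not the repository's class.

import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    def __init__(self, stops: list[torch.Tensor]):
        super().__init__()
        self.stops = stops  # each entry is a tensor of token ids that should end generation

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the tail of the generated sequence matches any stop sequence.
        for stop in self.stops:
            n = stop.shape[0]
            if input_ids.shape[1] >= n and torch.equal(input_ids[0, -n:], stop.to(input_ids.device)):
                return True
        return False
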
NatureLM/optims.py DELETED
@@ -1,154 +0,0 @@
1
- # This script is from https://github.com/salesforce/LAVIS/blob/main/lavis/common/optims.py
2
-
3
- import logging
4
- import math
5
-
6
- import torch
7
-
8
- from NatureLM.config import OptimizerConfig
9
-
10
-
11
- class LinearWarmupStepLRScheduler:
12
- def __init__(
13
- self,
14
- optimizer,
15
- max_epoch,
16
- min_lr,
17
- init_lr,
18
- decay_rate=1,
19
- warmup_start_lr=-1,
20
- warmup_steps=0,
21
- **kwargs,
22
- ):
23
- self.optimizer = optimizer
24
-
25
- self.max_epoch = max_epoch
26
- self.min_lr = min_lr
27
-
28
- self.decay_rate = decay_rate
29
-
30
- self.init_lr = init_lr
31
- self.warmup_steps = warmup_steps
32
- self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
33
-
34
- def step(self, cur_epoch, cur_step):
35
- if cur_epoch == 0:
36
- warmup_lr_schedule(
37
- step=cur_step,
38
- optimizer=self.optimizer,
39
- max_step=self.warmup_steps,
40
- init_lr=self.warmup_start_lr,
41
- max_lr=self.init_lr,
42
- )
43
- else:
44
- step_lr_schedule(
45
- epoch=cur_epoch,
46
- optimizer=self.optimizer,
47
- init_lr=self.init_lr,
48
- min_lr=self.min_lr,
49
- decay_rate=self.decay_rate,
50
- )
51
-
52
-
53
- class LinearWarmupCosineLRScheduler:
54
- def __init__(
55
- self,
56
- optimizer,
57
- max_epoch,
58
- iters_per_epoch,
59
- min_lr,
60
- init_lr,
61
- warmup_steps=0,
62
- warmup_start_lr=-1,
63
- **kwargs,
64
- ):
65
- self.optimizer = optimizer
66
-
67
- self.max_epoch = max_epoch
68
- self.iters_per_epoch = iters_per_epoch
69
- self.min_lr = min_lr
70
-
71
- self.init_lr = init_lr
72
- self.warmup_steps = warmup_steps
73
- self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
74
-
75
- def step(self, cur_epoch, cur_step):
76
- total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
77
- if total_cur_step < self.warmup_steps:
78
- warmup_lr_schedule(
79
- step=cur_step,
80
- optimizer=self.optimizer,
81
- max_step=self.warmup_steps,
82
- init_lr=self.warmup_start_lr,
83
- max_lr=self.init_lr,
84
- )
85
- else:
86
- cosine_lr_schedule(
87
- epoch=total_cur_step,
88
- optimizer=self.optimizer,
89
- max_epoch=self.max_epoch * self.iters_per_epoch,
90
- init_lr=self.init_lr,
91
- min_lr=self.min_lr,
92
- )
93
-
94
-
95
- def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
96
- """Decay the learning rate"""
97
- lr = (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr
98
- for param_group in optimizer.param_groups:
99
- param_group["lr"] = lr
100
-
101
-
102
- def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
103
- """Warmup the learning rate"""
104
- lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
105
- for param_group in optimizer.param_groups:
106
- param_group["lr"] = lr
107
-
108
-
109
- def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
110
- """Decay the learning rate"""
111
- lr = max(min_lr, init_lr * (decay_rate**epoch))
112
- for param_group in optimizer.param_groups:
113
- param_group["lr"] = lr
114
-
115
-
116
- def get_optimizer(model, config: OptimizerConfig):
117
- num_parameters = 0
118
- p_wd, p_non_wd = [], []
119
- for n, p in model.named_parameters():
120
- if not p.requires_grad:
121
- continue # frozen weights
122
- print(n)
123
- if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
124
- p_non_wd.append(p)
125
- else:
126
- p_wd.append(p)
127
- num_parameters += p.data.nelement()
128
- logging.info("number of trainable parameters: %d" % num_parameters)
129
- optim_params = [
130
- {
131
- "params": p_wd,
132
- "weight_decay": float(config.weight_decay),
133
- },
134
- {"params": p_non_wd, "weight_decay": 0},
135
- ]
136
- beta2 = config.beta2
137
- if config.device == "cpu":
138
- optimizer = torch.optim.AdamW(
139
- optim_params,
140
- lr=float(config.init_lr),
141
- weight_decay=float(config.weight_decay),
142
- betas=(0.9, beta2),
143
- )
144
- else:
145
- import bitsandbytes as bnb
146
-
147
- optimizer = bnb.optim.PagedAdamW8bit(
148
- optim_params,
149
- lr=float(config.init_lr),
150
- weight_decay=float(config.weight_decay),
151
- betas=(0.9, beta2),
152
- )
153
-
154
- return optimizer
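
For reference, the removed cosine_lr_schedule decays the rate as lr = (init_lr - min_lr) * 0.5 * (1 + cos(pi * step / max_step)) + min_lr. A standalone restatement (the function name here is illustrative):

import math

def cosine_lr(step: int, max_step: int, init_lr: float, min_lr: float) -> float:
    # Same formula as the deleted cosine_lr_schedule, returned instead of written into param groups.
    return (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * step / max_step)) + min_lr

# Halfway through training the rate sits midway between the bounds:
# cosine_lr(50, 100, 1e-4, 1e-5) ≈ 5.5e-5
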
NatureLM/processors.py CHANGED
@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
6
 
7
  import numpy as np
8
  import resampy
9
- import soundfile as sf
10
  import torch
11
 
12
 
@@ -49,7 +49,7 @@ class NatureLMAudioProcessor:
49
  def prepare_audio(self, audio: list[float] | np.ndarray | os.PathLike, input_sr: int = None) -> torch.Tensor:
50
  """Prepare an audio array or file path for inference"""
51
  if isinstance(audio, str | os.PathLike):
52
- audio, sr = sf.read(audio)
53
  input_sr = sr
54
  elif isinstance(audio, list):
55
  audio = np.array(audio)
 
6
 
7
  import numpy as np
8
  import resampy
9
+ import librosa
10
  import torch
11
 
12
 
 
49
  def prepare_audio(self, audio: list[float] | np.ndarray | os.PathLike, input_sr: int = None) -> torch.Tensor:
50
  """Prepare an audio array or file path for inference"""
51
  if isinstance(audio, str | os.PathLike):
52
+ audio, sr = librosa.load(audio, sr=None, mono=False)
53
  input_sr = sr
54
  elif isinstance(audio, list):
55
  audio = np.array(audio)
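
One behavioural difference introduced by this swap: soundfile.read returns a (frames, channels) float64 array, while librosa.load with sr=None and mono=False returns a float32 array shaped (channels, frames) at the file's native sample rate. A hedged sketch of handling the multi-channel case ("example.wav" is a placeholder path):

import librosa
import numpy as np

audio, sr = librosa.load("example.wav", sr=None, mono=False)  # native rate, channels first
if audio.ndim > 1:
    audio = np.mean(audio, axis=0)  # collapse channels before any resampling
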
NatureLM/runner.py DELETED
@@ -1,515 +0,0 @@
1
- # This script is based on https://github.com/salesforce/LAVIS/blob/main/lavis/runners/runner_base.py
2
-
3
- import datetime
4
- import json
5
- import logging
6
- import os
7
- import time
8
- from collections import defaultdict
9
- from pathlib import Path
10
-
11
- import torch
12
- import torch.distributed
13
- import torch.distributed as dist
14
- import wandb
15
- from torch.nn.parallel import DistributedDataParallel as DDP
16
- from torch.utils.tensorboard import SummaryWriter
17
-
18
- from NatureLM.config import Config
19
- from NatureLM.dist_utils import get_rank, get_world_size, is_dist_avail_and_initialized, is_main_process, main_process
20
- from NatureLM.logger import MetricLogger, SmoothedValue
21
- from NatureLM.optims import LinearWarmupCosineLRScheduler, get_optimizer
22
- from NatureLM.task_metrics import get_task_metrics
23
- from NatureLM.utils import get_dataloader, prepare_sample_dist
24
-
25
-
26
- class Runner:
27
- def __init__(self, cfg: Config, model, datasets, job_id):
28
- self.config = cfg
29
-
30
- # log
31
- device = "cuda:0"
32
- if is_main_process():
33
- if self.config.run.wandb_enabled:
34
- wandb.init(project="earthlm", config=self.config.model_dump())
35
- else:
36
- wandb.init(mode="disabled")
37
-
38
- if "LOCAL_RANK" in os.environ:
39
- device = int(os.environ["LOCAL_RANK"])
40
- else:
41
- device = self.config.run.device
42
- print(f"device is {device} could have been {self.config.run.device}")
43
- self.output_dir = Path(self.config.run.output_dir) / job_id
44
- self.output_dir.mkdir(parents=True, exist_ok=True)
45
- self.log_writter = SummaryWriter(self.output_dir)
46
-
47
- # settings
48
- self.device = torch.device(device)
49
- self.use_distributed = self.config.run.use_distributed
50
- self.start_epoch = 0
51
- self.max_epoch = self.config.run.optims.max_epoch
52
- self.evaluate_only = self.config.run.evaluate
53
- self.cuda_enabled = self.device.type == "cuda"
54
-
55
- # test prompt
56
- self.prompt_template = self.config.model.prompt_template
57
-
58
- # model
59
- self._model = model
60
- torch.nn.SyncBatchNorm.convert_sync_batchnorm(self._model)
61
- self._model.to(self.device)
62
- if self.use_distributed:
63
- self.model = DDP(
64
- self._model,
65
- find_unused_parameters=True,
66
- static_graph=False,
67
- device_ids=[self.device],
68
- )
69
- else:
70
- self.model = self._model
71
-
72
- # dataloaders
73
- self.train_loader = get_dataloader(
74
- datasets["train"],
75
- self.config.run,
76
- is_train=True,
77
- use_distributed=self.use_distributed,
78
- )
79
- self.valid_loader = get_dataloader(
80
- datasets["valid"],
81
- self.config.run,
82
- is_train=False,
83
- use_distributed=self.use_distributed,
84
- )
85
- self.test_loader = get_dataloader(
86
- datasets["test"],
87
- self.config.run,
88
- is_train=False,
89
- use_distributed=self.use_distributed,
90
- )
91
-
92
- # scaler
93
- self.use_amp = self.config.run.amp
94
- if self.use_amp:
95
- self.scaler = torch.cuda.amp.GradScaler()
96
- else:
97
- self.scaler = None
98
-
99
- # optimizer & scheduler
100
- self.iters_per_epoch = (
101
- len(self.train_loader) if self.config.run.epoch_based else self.config.run.iters_per_epoch
102
- )
103
- self.optimizer = get_optimizer(self.model, self.config.run.optims)
104
- self.scheduler = LinearWarmupCosineLRScheduler(
105
- self.optimizer,
106
- max_epoch=self.max_epoch,
107
- iters_per_epoch=self.iters_per_epoch,
108
- min_lr=self.config.run.optims.min_lr,
109
- init_lr=self.config.run.optims.init_lr,
110
- warmup_steps=self.config.run.optims.warmup_steps,
111
- warmup_start_lr=self.config.run.optims.warmup_start_lr,
112
- )
113
-
114
- #### augmentations
115
- # self.rng = random.Random(self.config.run.seed)
116
- # self.rngnp = np.random.default_rng(seed=self.config.run.seed)
117
- # self.rngth = torch.Generator(device=args.device)
118
- # self.rngth.manual_seed(self.config.run.seed)
119
- # augments = []
120
- # if self.config.run.augmentations.flip:
121
- # augments.append(augmentations.Flip(self.config.run.augmentations.flip, rngth=self.rngth, seed=self.config.run.seed))
122
- # if self.config.run.augmentations.bandmask:
123
- # augments.append(augmentations.BandMask(self.config.run.augmentations.bandmask, sample_rate=args.sample_rate, rng=self.rng, seed=self.config.run.seed))
124
- # if self.config.run.augmentations.revecho:
125
- # augments.append(
126
- # augmentations.RevEcho(proba=self.config.run.augmentations.revecho,rng=self.rng,seed=self.config.run.seed))
127
- # self.augment = torch.nn.Sequential(*augments)
128
-
129
- self.log_config()
130
-
131
- def unwrap_dist_model(self, model):
132
- if self.use_distributed:
133
- return model.module
134
- else:
135
- return model
136
-
137
- def train_epoch(self, epoch):
138
- self.model.train()
139
-
140
- metric_logger = MetricLogger(delimiter=" ")
141
- metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
142
- metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
143
-
144
- logging.info("Start training epoch {}, {} iters per inner epoch.".format(epoch, self.iters_per_epoch))
145
- header = "Train: data epoch: [{}]".format(epoch)
146
-
147
- # Get gradient clipping parameters from config
148
- clip_grad_norm = self.config.run.optims.max_grad_norm
149
- clip_grad_value = self.config.run.optims.max_grad_value
150
-
151
- for i in metric_logger.log_every(
152
- range(self.iters_per_epoch),
153
- self.config.run.log_freq,
154
- header=header,
155
- logger=self.log_writter,
156
- start_step=epoch * self.iters_per_epoch,
157
- ):
158
- if i >= self.iters_per_epoch:
159
- break
160
-
161
- samples = next(self.train_loader)
162
-
163
- samples = prepare_sample_dist(samples, self.device)
164
-
165
- #### augmentation
166
- # if False:
167
- # samples = self.augment(samples)
168
-
169
- self.scheduler.step(cur_epoch=epoch, cur_step=i)
170
-
171
- with torch.autocast(self.device.type, enabled=self.use_amp, dtype=torch.bfloat16):
172
- loss = self.model(samples)["loss"]
173
- if torch.isnan(loss):
174
- print("loss nan", samples)
175
- # continue
176
-
177
- if self.use_amp and self.scaler:
178
- self.scaler.scale(loss).backward()
179
- else:
180
- loss.backward()
181
-
182
- # Apply gradient clipping
183
- if clip_grad_norm is not None:
184
- if self.use_amp and self.scaler:
185
- self.scaler.unscale_(self.optimizer)
186
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=clip_grad_norm)
187
- if clip_grad_value is not None:
188
- if self.use_amp and self.scaler:
189
- self.scaler.unscale_(self.optimizer)
190
- torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=clip_grad_value)
191
-
192
- if (i + 1) % self.config.run.accum_grad_iters == 0:
193
- if self.use_amp and self.scaler:
194
- self.scaler.step(self.optimizer)
195
- self.scaler.update()
196
- else:
197
- self.optimizer.step()
198
- self.optimizer.zero_grad()
199
-
200
- metric_logger.update(loss=loss.item())
201
- metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
202
-
203
- metric_logger.synchronize_between_processes()
204
- logging.info("Averaged stats: " + str(metric_logger.global_avg()))
205
- return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
206
-
207
- @torch.no_grad()
208
- def valid_epoch(self, epoch, split, decode=True, save_json=False, decode_ratio=1.0):
209
- """
210
- Decode = True will lead to calculation of custom metrics which are based on text.
211
- decode_ratio controls the percentage of batches which will have custom metrics computed,
212
- a speed trade-off due to the cost of the 'generate' method.
213
- """
214
- model = self.unwrap_dist_model(self.model)
215
- model.eval()
216
-
217
- dataloader = getattr(self, split + "_loader", None)
218
- assert dataloader is not None, f"{split}_loader does not exist."
219
-
220
- metric_logger = MetricLogger(delimiter=" ")
221
- header = f"Eval: data epoch: [{epoch}]"
222
-
223
- results_per_task = defaultdict(list) # Store results per task
224
- overall_results = [] # Store all results for overall metrics
225
-
226
- # Calculate N based on decode_ratio
227
- if decode_ratio <= 0.0:
228
- N = float("inf") # Effectively never run generate
229
- elif decode_ratio >= 1.0:
230
- N = 1 # Run generate every batch
231
- else:
232
- N = max(int(1 / decode_ratio), 1) # Ensure N is at least 1
233
-
234
- batch_idx = 0
235
-
236
- # Initialize overall metrics
237
- overall_res = {
238
- "loss": torch.tensor(0.0, device=self.device),
239
- "correct": torch.tensor(0.0, device=self.device),
240
- "total": torch.tensor(0.0, device=self.device),
241
- }
242
-
243
- # Initialize per-task metrics
244
- per_task_res = defaultdict(
245
- lambda: {
246
- "loss": torch.tensor(0.0, device=self.device),
247
- "correct": torch.tensor(0.0, device=self.device),
248
- "total": torch.tensor(0.0, device=self.device),
249
- "n_sample": 0,
250
- "predicted_texts": [],
251
- "gold_texts": [],
252
- }
253
- )
254
-
255
- for samples in metric_logger.log_every(dataloader, self.config.run.log_freq, header=header):
256
- samples = prepare_sample_dist(samples, self.device)
257
-
258
- with torch.autocast(self.device.type, enabled=self.use_amp):
259
- forward_result = model(samples, verbose=True)
260
-
261
- # Extract batch-level loss and correct counts
262
- batch_loss = forward_result.get("loss", torch.tensor(0.0, device=self.device))
263
- batch_correct = forward_result.get("correct", torch.tensor(0.0, device=self.device))
264
- batch_total = forward_result.get("total", torch.tensor(1.0, device=self.device))
265
-
266
- batch_size = len(samples["id"])
267
-
268
- # Update overall metrics with batch-level values
269
- overall_res["loss"] += batch_loss.detach()
270
- overall_res["correct"] += batch_correct.detach()
271
- overall_res["total"] += batch_total.detach()
272
-
273
- # Decide whether to run generate based on decode_ratio
274
- if decode and (batch_idx % N == 0):
275
- prompts = samples.get("prompt", None)
276
- try:
277
- generated_texts = model.generate(samples, self.config.generate, prompts=prompts)
278
- except Exception as e:
279
- print("error in generation", e)
280
- generated_texts = [None] * batch_size
281
- else:
282
- generated_texts = [None] * batch_size # Placeholder if not decoding
283
-
284
- # Process per-sample data for per-task metrics and result saving
285
- for i in range(batch_size):
286
- task = samples["task"][i]
287
-
288
- # Collect per-task batch-level metrics
289
- per_task_res[task]["loss"] += batch_loss.detach()
290
- per_task_res[task]["correct"] += batch_correct.detach()
291
- per_task_res[task]["total"] += batch_total.detach()
292
- per_task_res[task]["n_sample"] += 1
293
-
294
- res = {
295
- "id": samples["id"][i],
296
- "ground_truth": samples["text"][i], # Gold label from dataloader
297
- "task": task,
298
- "predicted_text": generated_texts[i],
299
- }
300
-
301
- if decode and generated_texts[i] is not None:
302
- res["prompt"] = samples.get("prompt", [None])[i]
303
-
304
- results_per_task[task].append(res)
305
- overall_results.append(res)
306
-
307
- # Collect texts for custom metrics
308
- if generated_texts[i] is not None:
309
- per_task_res[task]["predicted_texts"].append(generated_texts[i])
310
- per_task_res[task]["gold_texts"].append(samples["text"][i])
311
-
312
- batch_idx += 1 # Increment batch index
313
-
314
- if save_json:
315
- for task, task_results in results_per_task.items():
316
- self.save_result(task_results, self.output_dir, f"eval_{split}_{task}_epoch_{epoch}")
317
- # Optionally save overall results
318
- self.save_result(overall_results, self.output_dir, f"eval_{split}_epoch_{epoch}")
319
-
320
- # Synchronize metrics across processes if in distributed mode
321
- if is_dist_avail_and_initialized():
322
- for key in overall_res:
323
- dist.all_reduce(overall_res[key])
324
-
325
- overall_ret = {
326
- "loss": (overall_res["loss"] / batch_idx).item(),
327
- "agg_metrics": (overall_res["correct"] / overall_res["total"]).item(),
328
- }
329
-
330
- if is_main_process():
331
- # Log overall metrics
332
- wandb.log(
333
- {
334
- f"{split}_loss": overall_ret["loss"],
335
- f"{split}_accuracy": overall_ret["agg_metrics"],
336
- "epoch": epoch,
337
- }
338
- )
339
-
340
- # Compute and log per-task metrics
341
- for task, res in per_task_res.items():
342
- if "caption-none" in task:
343
- continue
344
-
345
- if self.use_distributed:
346
- print(f"Rank {dist.get_rank()}, task={task}, ")
347
-
348
- print(
349
- f"loss={res['loss'].shape, res['loss'].dtype}, "
350
- f"correct={res['correct'].shape, res['correct'].dtype}, "
351
- f"total={res['total'].shape, res['total'].dtype}, "
352
- f"n_sample={res['n_sample']}"
353
- )
354
-
355
- # Synchronize metrics across processes if in distributed mode
356
- if is_dist_avail_and_initialized():
357
- dist.all_reduce(res["loss"])
358
- dist.all_reduce(res["correct"])
359
- dist.all_reduce(res["total"])
360
- dist.all_reduce(torch.tensor(res["n_sample"], device=self.device))
361
-
362
- ret = {
363
- "loss": (res["loss"] / res["n_sample"]).item(),
364
- "agg_metrics": (res["correct"] / res["total"]).item(),
365
- }
366
-
367
- if is_main_process():
368
- # Log per-task metrics
369
- wandb.log(
370
- {
371
- f"{split}_{task}_loss": ret["loss"],
372
- f"{split}_{task}_accuracy": ret["agg_metrics"],
373
- "epoch": epoch,
374
- }
375
- )
376
-
377
- # Get and compute custom metrics for this task
378
- metrics_list = get_task_metrics(task)
379
- predicted_texts = res["predicted_texts"]
380
- gold_texts = res["gold_texts"]
381
- for metric in metrics_list:
382
- if predicted_texts and gold_texts:
383
- metric_value = metric.compute_metric(predicted_texts, gold_texts)
384
- metric_name = metric.__class__.__name__
385
- wandb.log(
386
- {
387
- f"{split}_{task}_{metric_name}": metric_value,
388
- "epoch": epoch,
389
- }
390
- )
391
- return overall_ret # Return overall metrics
392
-
393
- def save_result(self, result, result_dir, filename):
394
- result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, get_rank()))
395
- final_result_file = os.path.join(result_dir, "%s.json" % filename)
396
-
397
- try:
398
- json.dump(result, open(result_file, "w"), ensure_ascii=False)
399
- except Exception as e:
400
- logging.warning(f"Error saving {result_file}. Error: {e}")
401
- json.dump(result, open(result_file, "w", encoding="utf-8"), ensure_ascii=False)
402
-
403
- # if is_dist_avail_and_initialized():
404
- # dist.barrier()
405
-
406
- if is_main_process():
407
- logging.info("rank %d starts merging results." % get_rank())
408
- result = []
409
-
410
- for rank in range(get_world_size()):
411
- result_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
412
- try:
413
- res = json.load(open(result_file, "r"))
414
- except Exception as e:
415
- logging.warning(f"Error reading {result_file}. Error: {e}")
416
- res = json.load(open(result_file, "r", encoding="utf-8"))
417
- result += res
418
-
419
- try:
420
- json.dump(result, open(final_result_file, "w"), ensure_ascii=False)
421
- except Exception as e:
422
- logging.warning(f"Error saving {final_result_file}. Error: {e}")
423
- json.dump(
424
- result,
425
- open(final_result_file, "w", encoding="utf-8"),
426
- ensure_ascii=False,
427
- )
428
-
429
- print("result file saved to %s" % final_result_file)
430
-
431
- def train(self):
432
- start_time = time.time()
433
- best_agg_metric = 0
434
- best_epoch = 0
435
-
436
- for cur_epoch in range(self.start_epoch, self.max_epoch):
437
- if self.evaluate_only:
438
- break
439
-
440
- # training phase
441
- logging.info("Training Phase")
442
- train_stats = self.train_epoch(cur_epoch)
443
- self.log_stats(train_stats, split_name="train")
444
-
445
- # validating phase
446
- logging.info("Validating Phase")
447
- valid_log = self.valid_epoch(
448
- cur_epoch,
449
- "valid",
450
- decode=self.config.run.custom_metrics,
451
- save_json=False,
452
- decode_ratio=self.config.run.decode_ratio,
453
- )
454
- if valid_log is not None:
455
- if is_main_process():
456
- agg_metrics = valid_log["agg_metrics"]
457
- if agg_metrics > best_agg_metric:
458
- best_agg_metric = agg_metrics
459
- best_epoch = cur_epoch
460
- self.save_checkpoint(cur_epoch, is_best=True)
461
-
462
- valid_log.update({"best_epoch": best_epoch})
463
- self.log_stats(valid_log, split_name="valid")
464
- self.save_checkpoint(cur_epoch, is_best=False)
465
-
466
- # if self.use_distributed:
467
- # dist.barrier()
468
-
469
- # testing phase
470
- if self.evaluate_only:
471
- self.valid_epoch("best", "test", decode=True, save_json=True)
472
-
473
- total_time = time.time() - start_time
474
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
475
- logging.info("Training time {}".format(total_time_str))
476
-
477
- @main_process
478
- def log_config(self):
479
- with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
480
- f.write(json.dumps(self.config.model_dump_json(), indent=4) + "\n")
481
-
482
- @main_process
483
- def log_stats(self, stats, split_name):
484
- if isinstance(stats, dict):
485
- log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
486
- with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
487
- f.write(json.dumps(log_stats) + "\n")
488
- elif isinstance(stats, list):
489
- pass
490
-
491
- @main_process
492
- def save_checkpoint(self, cur_epoch, is_best=False):
493
- """
494
- Save the checkpoint at the current epoch.
495
- """
496
- model_no_ddp = self.unwrap_dist_model(self.model)
497
- param_grad_dic = {k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()}
498
- state_dict = model_no_ddp.state_dict()
499
- for k in list(state_dict.keys()):
500
- if k in param_grad_dic.keys() and not param_grad_dic[k]:
501
- # delete parameters that do not require gradient
502
- del state_dict[k]
503
- save_obj = {
504
- "model": state_dict,
505
- "optimizer": self.optimizer.state_dict(),
506
- "config": dict(self.config),
507
- "scaler": self.scaler.state_dict() if self.scaler else None,
508
- "epoch": cur_epoch,
509
- }
510
- save_to = os.path.join(
511
- self.output_dir,
512
- "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
513
- )
514
- logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
515
- torch.save(save_obj, save_to)
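
The decode_ratio logic in the deleted valid_epoch trades evaluation speed against metric coverage: generate() only runs on every N-th batch. A standalone restatement of that interval computation (the function name is illustrative):

def batches_between_decodes(decode_ratio: float):
    if decode_ratio <= 0.0:
        return float("inf")  # never decode
    if decode_ratio >= 1.0:
        return 1  # decode every batch
    return max(int(1 / decode_ratio), 1)

# e.g. decode_ratio=0.25 runs text generation on every 4th validation batch.
assert batches_between_decodes(0.25) == 4
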
NatureLM/storage_utils.py DELETED
@@ -1,26 +0,0 @@
1
- import logging
2
- import os
3
- from functools import lru_cache
4
- from typing import Union
5
-
6
- import cloudpathlib
7
- from google.cloud.storage.client import Client
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- def is_gcs_path(path: Union[str, os.PathLike]) -> bool:
13
- return str(path).startswith("gs://")
14
-
15
-
16
- @lru_cache(maxsize=1)
17
- def _get_client():
18
- return cloudpathlib.GSClient(storage_client=Client())
19
-
20
-
21
- try:
22
- _gcp_storage_client = _get_client()
23
- except Exception:
24
- logger.warning("Failed to initialize GCS client. Training won't be able to use GSPath or R2Path without a client.")
25
- _gcp_storage_client = None
26
-
NatureLM/task_metric_utils.py DELETED
@@ -1,283 +0,0 @@
1
- # Taken from DCASE 2021 Task 5 evaluation source code
2
- # https://github.com/c4dm/dcase-few-shot-bioacoustic
3
- # MIT License
4
-
5
- import mir_eval
6
- import numpy as np
7
- import scipy
8
-
9
-
10
- def fast_intersect(ref, est):
11
- """Find all intersections between reference events and estimated events (fast).
12
- Best-case complexity: O(N log N + M log M) where N=length(ref) and M=length(est)
13
- Parameters
14
- ----------
15
- ref: np.ndarray [shape=(2, n)], real-valued
16
- Array of reference events. Each column is an event.
17
- The first row denotes onset times and the second row denotes offset times.
18
- est: np.ndarray [shape=(2, m)], real-valued
19
- Array of estimated events. Each column is an event.
20
- The first row denotes onset times and the second row denotes offset times.
21
- Returns
22
- -------
23
- matches: list of sets, length n, integer-valued
24
- Property: matches[i] contains the set of all indices j such that
25
- (ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
26
- """
27
- ref_on_argsort = np.argsort(ref[0, :])
28
- ref_off_argsort = np.argsort(ref[1, :])
29
-
30
- est_on_argsort = np.argsort(est[0, :])
31
- est_off_argsort = np.argsort(est[1, :])
32
-
33
- est_on_maxindex = est.shape[1]
34
- est_off_minindex = 0
35
- estref_matches = [set()] * ref.shape[1]
36
- refest_matches = [set()] * ref.shape[1]
37
- for ref_id in range(ref.shape[1]):
38
- ref_onset = ref[0, ref_on_argsort[ref_id]]
39
- est_off_sorted = est[1, est_off_argsort[est_off_minindex:]]
40
- search_result = np.searchsorted(est_off_sorted, ref_onset, side="left")
41
- est_off_minindex += search_result
42
- refest_match = est_off_argsort[est_off_minindex:]
43
- refest_matches[ref_on_argsort[ref_id]] = set(refest_match)
44
-
45
- ref_offset = ref[1, ref_off_argsort[-1 - ref_id]]
46
- est_on_sorted = est[0, est_on_argsort[: (1 + est_on_maxindex)]]
47
- search_result = np.searchsorted(est_on_sorted, ref_offset, side="right")
48
- est_on_maxindex = search_result - 1
49
- estref_match = est_on_argsort[: (1 + est_on_maxindex)]
50
- estref_matches[ref_off_argsort[-1 - ref_id]] = set(estref_match)
51
-
52
- zip_iterator = zip(refest_matches, estref_matches)
53
- matches = [x.intersection(y) for (x, y) in zip_iterator]
54
- return matches
55
-
56
-
57
- def iou(ref, est, method="fast"):
58
- """Compute pairwise "intersection over union" (IOU) metric between reference
59
- events and estimated events.
60
- Let us denote by a_i and b_i the onset and offset of reference event i.
61
- Let us denote by u_j and v_j the onset and offset of estimated event j.
62
- The IOU between events i and j is defined as
63
- (min(b_i, v_j)-max(a_i, u_j)) / (max(b_i, v_j)-min(a_i, u_j))
64
- if the events are non-disjoint, and equal to zero otherwise.
65
- Parameters
66
- ----------
67
- ref: np.ndarray [shape=(2, n)], real-valued
68
- Array of reference events. Each column is an event.
69
- The first row denotes onset times and the second row denotes offset times.
70
- est: np.ndarray [shape=(2, m)], real-valued
71
- Array of estimated events. Each column is an event.
72
- The first row denotes onset times and the second row denotes offset times.
73
- method: str, optional.
74
- If "fast" (default), computes pairwise intersections via a custom
75
- dynamic programming algorithm, see fast_intersect.
76
- If "slow", computes pairwise intersections via bruteforce quadratic
77
- search, see slow_intersect.
78
- Returns
79
- -------
80
- S: scipy.sparse.dok.dok_matrix, real-valued
81
- Sparse 2-D matrix. S[i,j] contains the IOU between ref[i] and est[j]
82
- if these events are non-disjoint and zero otherwise.
83
- """
84
- n_refs = ref.shape[1]
85
- n_ests = est.shape[1]
86
- S = scipy.sparse.dok_matrix((n_refs, n_ests))
87
-
88
- if method == "fast":
89
- matches = fast_intersect(ref, est)
90
- elif method == "slow":
91
- matches = slow_intersect(ref, est)
92
-
93
- for ref_id in range(n_refs):
94
- matching_ests = matches[ref_id]
95
- ref_on = ref[0, ref_id]
96
- ref_off = ref[1, ref_id]
97
-
98
- for matching_est_id in matching_ests:
99
- est_on = est[0, matching_est_id]
100
- est_off = est[1, matching_est_id]
101
- intersection = min(ref_off, est_off) - max(ref_on, est_on)
102
- union = max(ref_off, est_off) - min(ref_on, est_on)
103
- intersection_over_union = intersection / union
104
- S[ref_id, matching_est_id] = intersection_over_union
105
-
106
- return S
107
-
108
- def compute_intersection(ref, est, method="fast"):
109
- """Compute pairwise intersection between reference
110
- events and estimated events.
111
- Let us denote by a_i and b_i the onset and offset of reference event i.
112
- Let us denote by u_j and v_j the onset and offset of estimated event j.
113
- The Intersection between events i and j is defined as
114
- (min(b_i, v_j)-max(a_i, u_j))
115
- if the events are non-disjoint, and equal to zero otherwise.
116
- Parameters
117
- ----------
118
- ref: np.ndarray [shape=(2, n)], real-valued
119
- Array of reference events. Each column is an event.
120
- The first row denotes onset times and the second row denotes offset times.
121
- est: np.ndarray [shape=(2, m)], real-valued
122
- Array of estimated events. Each column is an event.
123
- The first row denotes onset times and the second row denotes offset times.
124
- method: str, optional.
125
- If "fast" (default), computes pairwise intersections via a custom
126
- dynamic programming algorithm, see fast_intersect.
127
- If "slow", computes pairwise intersections via bruteforce quadratic
128
- search, see slow_intersect.
129
- Returns
130
- -------
131
- S: scipy.sparse.dok.dok_matrix, real-valued
132
- Sparse 2-D matrix. S[i,j] contains the Intersection between ref[i] and est[j]
133
- if these events are non-disjoint and zero otherwise.
134
- """
135
- n_refs = ref.shape[1]
136
- n_ests = est.shape[1]
137
- S = scipy.sparse.dok_matrix((n_refs, n_ests))
138
-
139
- if method == "fast":
140
- matches = fast_intersect(ref, est)
141
- elif method == "slow":
142
- matches = slow_intersect(ref, est)
143
-
144
- for ref_id in range(n_refs):
145
- matching_ests = matches[ref_id]
146
- ref_on = ref[0, ref_id]
147
- ref_off = ref[1, ref_id]
148
-
149
- for matching_est_id in matching_ests:
150
- est_on = est[0, matching_est_id]
151
- est_off = est[1, matching_est_id]
152
- intersection = min(ref_off, est_off) - max(ref_on, est_on)
153
- # union = max(ref_off, est_off) - min(ref_on, est_on)
154
- # intersection_over_union = intersection / union
155
- S[ref_id, matching_est_id] = intersection #_over_union
156
-
157
- return S
158
-
159
-
160
- def match_events(ref, est, min_iou=0.0, method="fast"):
161
- """
162
- Compute a maximum matching between reference and estimated event times,
163
- subject to a criterion of minimum intersection-over-union (IOU).
164
- Given two lists of events ``ref`` (reference) and ``est`` (estimated),
165
- we seek the largest set of correspondences ``(ref[i], est[j])`` such that
166
- ``iou(ref[i], est[j]) > min_iou``
167
- and such that each ``ref[i]`` and ``est[j]`` is matched at most once.
168
- This function is strongly inspired by mir_eval.onset.util.match_events.
169
- It relies on mir_eval's implementation of the Hopcroft-Karp algorithm from
170
- maximum bipartite graph matching. However, one important difference is that
171
- mir_eval's distance function relies purely on onset times, whereas this function
172
- considers both onset times and offset times to compute the IOU metric between
173
- reference events and estimated events.
174
- Parameters
175
- ----------
176
- ref: np.ndarray [shape=(2, n)], real-valued
177
- Array of reference events. Each column is an event.
178
- The first row denotes onset times and the second row denotes offset times.
179
- est: np.ndarray [shape=(2, m)], real-valued
180
- Array of estimated events. Each column is an event.
181
- The first row denotes onset times and the second row denotes offset times.
182
- min_iou: real number in [0, 1). Default: 0.
183
- Threshold for minimum amount of intersection over union (IOU) to match
184
- any two events. See the iou method for implementation details.
185
- method: str, optional.
186
- If "fast" (default), computes pairwise intersections via a custom
187
- dynamic programming algorithm, see fast_intersect.
188
- If "slow", computes pairwise intersections via bruteforce quadratic
189
- search, see slow_intersect.
190
- Returns
191
- -------
192
- matching : list of tuples
193
- Every tuple corresponds to a match between one reference event and
194
- one estimated event.
195
- ``matching[i] == (i, j)`` where ``ref[i]`` matches ``est[j]``.
196
- Note that all values i and j appear at most once in the list.
197
- """
198
-
199
- # Intersect reference events and estimated events
200
- S = iou(ref, est, method=method)
201
-
202
- # Threshold intersection-over-union (IOU) ratio
203
- S_bool = scipy.sparse.dok_matrix(S > min_iou)
204
- hits = S_bool.keys()
205
-
206
- # Construct the bipartite graph
207
- G = {}
208
- for ref_i, est_i in hits:
209
- if est_i not in G:
210
- G[est_i] = []
211
- G[est_i].append(ref_i)
212
-
213
- # Apply Hopcroft-Karp algorithm (from mir_eval package)
214
- # to obtain maximum bipartite graph matching
215
- matching = sorted(mir_eval.util._bipartite_match(G).items())
216
- return matching
217
-
218
-
219
- def slow_intersect(ref, est):
220
- """Find all intersections between reference events and estimated events (slow).
221
- Best-case complexity: O(N*M) where N=ref.shape[1] and M=est.shape[1]
222
- Parameters
223
- ----------
224
- ref: np.ndarray [shape=(2, n)], real-valued
225
- Array of reference events. Each column is an event.
226
- The first row denotes onset times and the second row denotes offset times.
227
- est: np.ndarray [shape=(2, m)], real-valued
228
- Array of estimated events. Each column is an event.
229
- The first row denotes onset times and the second row denotes offset times.
230
- Returns
231
- -------
232
- matches: list of sets, length n, integer-valued
233
- Property: matches[i] contains the set of all indices j such that
234
- (ref[0, i]<=est[1, j]) AND (ref[1, i]>=est[0, j])
235
- """
236
- matches = []
237
- for i in range(ref.shape[1]):
238
- matches.append(
239
- set(
240
- [
241
- j
242
- for j in range(est.shape[1])
243
- if ((ref[0, i] <= est[1, j]) and (ref[1, i] >= est[0, j]))
244
- ]
245
- )
246
- )
247
- return matches
248
-
249
-
250
- def frames_to_st_dict(x, sr=16000):
251
- # x : Tensor of shape (batch, time) or (time,). Entries are 2 (POS), 1 (UNK), and 0 (NEG).
252
- # returns a list of dicts {"Begin Time (s)" : [...], "End Time (s)" : [...], "Annotation" : [...]} if batch dim exists, or a single dict
253
-
254
- if len(x.size()) == 2:
255
- outs = []
256
- for i in range(x.size(0)):
257
- x_sub = x[i,:]
258
- outs.append(_frames_to_st_dict_single(x_sub, sr=sr))
259
- return outs
260
- else:
261
- return _frames_to_st_dict_single(x, sr=sr)
262
-
263
- def _frames_to_st_dict_single(x, sr=16000):
264
- d = {"Begin Time (s)" : [], "End Time (s)" : [], "Annotation" : []}
265
-
266
- for label_i in [1,2]:
267
-
268
- labels = x.numpy() == label_i # POS : 2, UNK : 1, NEG : 0
269
-
270
- starts = np.where((~labels[:-1]) & (labels[1:]))[0] + 1
271
- if labels[0]:
272
- starts = np.insert(starts, 0, 0)
273
-
274
- ends = np.where((labels[:-1]) & (~labels[1:]))[0] + 1
275
- if labels[-1]:
276
- ends = np.append(ends, len(labels))
277
-
278
- for start, end in zip(starts, ends):
279
- d["Begin Time (s)"].append(start/sr)
280
- d["End Time (s)"].append(end/sr)
281
- d["Annotation"].append("POS" if label_i == 2 else "UNK")
282
-
283
- return d
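
A worked example of the pairwise event IOU computed above: events (0 s, 2 s) and (1 s, 3 s) overlap for 1 s and jointly span 3 s, so their IOU is 1/3. A self-contained restatement (helper name illustrative):

def event_iou(a_on: float, a_off: float, b_on: float, b_off: float) -> float:
    intersection = min(a_off, b_off) - max(a_on, b_on)
    if intersection <= 0:
        return 0.0  # disjoint events
    union = max(a_off, b_off) - min(a_on, b_on)
    return intersection / union

assert abs(event_iou(0.0, 2.0, 1.0, 3.0) - 1 / 3) < 1e-9
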
NatureLM/task_metrics.py DELETED
@@ -1,128 +0,0 @@
1
- import re
2
- from abc import ABC, abstractmethod
3
- from typing import List, Tuple
4
-
5
- import numpy as np
6
-
7
- from NatureLM.task_metric_utils import match_events
8
-
9
- # Assume the following functions are imported from the reference implementations:
10
- # - match_events
11
- # - iou
12
- # - fast_intersect
13
- # - slow_intersect
14
- # - compute_intersection
15
-
16
-
17
- class Metric(ABC):
18
- @abstractmethod
19
- def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
20
- pass
21
-
22
-
23
- class ExactAccuracy(Metric):
24
- """Exact-match accuracy metric."""
25
-
26
- def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
27
- predicted_texts = [pt.lower().strip() for pt in predicted_texts]
28
- gold_texts = [gt.lower().strip() for gt in gold_texts]
29
- correct = sum(p == g for p, g in zip(predicted_texts, gold_texts))
30
- return correct / len(gold_texts) if gold_texts else 0.0
31
-
32
-
33
- class FewShot(Metric):
34
- """Few-shot learning metric based on event matching using IoU."""
35
-
36
- def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
37
- # Initialize counts
38
- total_TP = 0
39
- total_FP = 0
40
- total_FN = 0
41
-
42
- for pred_text, gold_text in zip(predicted_texts, gold_texts):
43
- # Extract events from texts
44
- pred_events = parse_timestamps_from_text(pred_text)
45
- gold_events = parse_timestamps_from_text(gold_text)
46
-
47
- # Convert events to numpy arrays for match_events function
48
- # Each event is (start_time, end_time), need to transpose to shape (2, n)
49
- pred_array = np.array(pred_events).T if pred_events else np.empty((2, 0))
50
- gold_array = np.array(gold_events).T if gold_events else np.empty((2, 0))
51
-
52
- # Use match_events function from the reference implementation
53
- matches = match_events(gold_array, pred_array, min_iou=0.5, method="fast")
54
-
55
- TP = len(matches)
56
- FP = len(pred_events) - TP
57
- FN = len(gold_events) - TP
58
-
59
- total_TP += TP
60
- total_FP += FP
61
- total_FN += FN
62
-
63
- # Compute precision, recall, and F1 score
64
- precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0.0
65
- recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0.0
66
- f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
67
-
68
- return f1_score
69
-
70
-
71
- class NoneAccuracy(Metric):
72
- """Accuracy for cases where 'None' is the correct answer."""
73
-
74
- def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
75
- # Normalize texts
76
- predicted_texts = [pt.lower().strip() for pt in predicted_texts]
77
- gold_texts = [gt.lower().strip() for gt in gold_texts]
78
- # Filter indices where gold_text is 'none'
79
- indices = [i for i, gt in enumerate(gold_texts) if gt == "none"]
80
- if not indices:
81
- return 0.0 # No 'None' cases in gold_texts
82
- correct = sum(predicted_texts[i] == "none" for i in indices)
83
- return correct / len(indices)
84
-
85
-
86
- class MultipleSpeciesAccuracy(Metric):
87
- """Accuracy for cases where the correct answer has at least one comma (multiple species)."""
88
-
89
- def compute_metric(self, predicted_texts: List[str], gold_texts: List[str]) -> float:
90
- # Normalize texts
91
- predicted_texts = [pt.lower().strip() for pt in predicted_texts]
92
- gold_texts = [gt.lower().strip() for gt in gold_texts]
93
- # Filter indices where gold_text contains at least one comma
94
- indices = [i for i, gt in enumerate(gold_texts) if "," in gt]
95
- if not indices:
96
- return 0.0 # No multiple-species cases in gold_texts
97
- correct = sum(predicted_texts[i] == gold_texts[i] for i in indices)
98
- return correct / len(indices)
99
-
100
-
101
- def get_task_metrics(task: str) -> List[Metric]:
102
- """Get a list of metric instances appropriate for the given task."""
103
- all_metrics = []
104
- metrics_dict = {}
105
-
106
- if "classification" in task:
107
- metrics_dict["ExactAccuracy"] = ExactAccuracy()
108
- if "fewshot" in task:
109
- metrics_dict["FewShot"] = FewShot()
110
- if "detection" in task:
111
- metrics_dict["ExactAccuracy"] = ExactAccuracy() # Ensures no duplicate
112
- metrics_dict["NoneAccuracy"] = NoneAccuracy()
113
- metrics_dict["MultipleSpeciesAccuracy"] = MultipleSpeciesAccuracy()
114
-
115
- all_metrics = list(metrics_dict.values())
116
- return all_metrics
117
-
118
-
119
- def parse_timestamps_from_text(text: str) -> List[Tuple[float, float]]:
120
- """
121
- Function to parse timestamps from text.
122
- Extracts timestamps in the format "start-end" where start and end are floats.
123
- """
124
- # Regular expression to extract timestamps in the format "start-end"
125
- pattern = r"(\d+\.\d+)-(\d+\.\d+)"
126
- matches = re.findall(pattern, text)
127
- events = [(float(start), float(end)) for start, end in matches]
128
- return events
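
For context, the regex above extracts "start-end" pairs written with decimal points, so a model output such as "calls at 0.25-1.75 and 3.00-4.50" yields [(0.25, 1.75), (3.0, 4.5)]. A standalone check (function name illustrative):

import re

def parse_timestamps(text: str) -> list[tuple[float, float]]:
    return [(float(a), float(b)) for a, b in re.findall(r"(\d+\.\d+)-(\d+\.\d+)", text)]

assert parse_timestamps("calls at 0.25-1.75 and 3.00-4.50") == [(0.25, 1.75), (3.0, 4.5)]
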
NatureLM/utils.py CHANGED
@@ -25,9 +25,7 @@ import soundfile as sf
25
  import torch
26
  import torch.nn.functional as F
27
  import torchaudio
28
- from torch.utils.data import DataLoader, DistributedSampler
29
-
30
- from NatureLM.dist_utils import get_rank, get_world_size
31
 
32
  logger = logging.getLogger(__name__)
33
 
@@ -99,29 +97,6 @@ def now_as_str() -> str:
99
  return datetime.now().strftime("%Y%m%d%H%M")
100
 
101
 
102
- def get_dataloader(dataset, config, is_train=True, use_distributed=True):
103
- if use_distributed:
104
- sampler = DistributedSampler(dataset, shuffle=is_train, num_replicas=get_world_size(), rank=get_rank())
105
- else:
106
- sampler = None
107
-
108
- loader = DataLoader(
109
- dataset,
110
- batch_size=config.batch_size_train if is_train else config.batch_size_eval,
111
- num_workers=config.num_workers,
112
- pin_memory=False,
113
- sampler=sampler,
114
- shuffle=sampler is None and is_train,
115
- collate_fn=dataset.collater,
116
- drop_last=is_train,
117
- )
118
-
119
- if is_train:
120
- loader = IterLoader(loader, use_distributed=use_distributed)
121
-
122
- return loader
123
-
124
-
125
  def apply_to_sample(f, sample):
126
  if len(sample) == 0:
127
  return {}
 
25
  import torch
26
  import torch.nn.functional as F
27
  import torchaudio
28
+ from torch.utils.data import DataLoader
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
97
  return datetime.now().strftime("%Y%m%d%H%M")
98
 
99
 
100
  def apply_to_sample(f, sample):
101
  if len(sample) == 0:
102
  return {}
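
With get_dataloader and the distributed sampler removed, inference-time callers can fall back to a plain DataLoader. A minimal sketch, assuming the datasets still expose a collater method as the deleted helper did:

from torch.utils.data import DataLoader

def simple_loader(dataset, batch_size: int = 1, num_workers: int = 0) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # keep evaluation order deterministic
        collate_fn=getattr(dataset, "collater", None),
    )
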
Space.yaml CHANGED
@@ -1,3 +1,3 @@
1
  sdk: gradio
2
  python_version: 3.10
3
- hardware: cpu
 
1
  sdk: gradio
2
  python_version: 3.10
3
+ hardware: gpu
app.py CHANGED
@@ -1,37 +1,86 @@
1
- import re
2
- import tempfile
3
- from collections import Counter
4
  from pathlib import Path
5
- from typing import Literal, Optional
 
6
 
7
  import gradio as gr
8
  import torch
9
 
10
  from NatureLM.config import Config
11
  from NatureLM.models.NatureLM import NatureLM
12
- from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms
13
  import spaces
14
 
15
 
16
  class ModelManager:
17
  """Manages model loading and state"""
18
-
19
  def __init__(self):
20
  self.model: Optional[NatureLM] = None
21
  self.config: Optional[Config] = None
22
  self.is_loaded = False
23
  self.is_loading = False
24
  self.load_failed = False
25
-
26
  def check_availability(self) -> tuple[bool, str]:
27
  """Check if the model is available for download"""
28
  try:
29
  from huggingface_hub import model_info
 
30
  info = model_info("EarthSpeciesProject/NatureLM-audio")
31
  return True, "Model is available"
32
  except Exception as e:
33
  return False, f"Model not available: {str(e)}"
34
-
35
  def reset_state(self):
36
  """Reset the model loading state to allow retrying after a failure"""
37
  self.model = None
@@ -39,7 +88,7 @@ class ModelManager:
39
  self.is_loading = False
40
  self.load_failed = False
41
  return self.get_status()
42
-
43
  def get_status(self) -> str:
44
  """Get the current model loading status"""
45
  if self.is_loaded:
@@ -50,34 +99,35 @@ class ModelManager:
50
  return "❌ Model failed to load. Please check the configuration."
51
  else:
52
  return "⏳ Ready to load model on first use"
53
-
54
  def load_model(self) -> Optional[NatureLM]:
55
  """Load the model if needed"""
56
  if self.is_loaded:
57
  return self.model
58
-
59
  if self.is_loading or self.load_failed:
60
  return None
61
-
62
  try:
63
  self.is_loading = True
64
  print("Loading model...")
65
-
66
  # Check if model is available first
67
  available, message = self.check_availability()
68
  if not available:
69
  raise Exception(f"Model not available: {message}")
70
-
71
  model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
72
- model.to("cuda")
73
  model.eval()
74
-
75
- self.model = model
 
76
  self.is_loaded = True
77
  self.is_loading = False
78
  print("Model loaded successfully!")
79
- return model
80
-
81
  except Exception as e:
82
  print(f"Error loading model: {e}")
83
  self.is_loading = False
@@ -88,12 +138,44 @@ class ModelManager:
88
  # Global model manager instance
89
  model_manager = ModelManager()
90
 
91
-
92
  @spaces.GPU
93
- def prompt_lm(audios: list[str], messages: list[dict[str, str]]) -> str:
94
- """Generate response using the model"""
95
  model = model_manager.load_model()
96
-
97
  if model is None:
98
  if model_manager.is_loading:
99
  return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
@@ -101,284 +183,63 @@ def prompt_lm(audios: list[str], messages: list[dict[str, str]]) -> str:
101
  return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again using the retry button."
102
  else:
103
  return "Demo mode: Model not loaded. Please check the model configuration."
104
-
105
- cuda_enabled = torch.cuda.is_available()
106
- samples = prepare_sample_waveforms(audios, cuda_enabled)
107
- prompt_text = model.llama_tokenizer.apply_chat_template(
108
- messages, tokenize=False, add_generation_prompt=True
109
- ).removeprefix(model.llama_tokenizer.bos_token)
110
-
111
- prompt_text = re.sub(
112
- r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
113
- "",
114
- prompt_text,
115
- )
116
- prompt_text = re.sub("\\n", r"\\n", prompt_text)
117
-
118
- print(f"{prompt_text=}")
119
- with torch.cuda.amp.autocast(dtype=torch.float16):
120
- llm_answer = model.generate(samples, model_manager.config.generate, prompts=[prompt_text])
121
- return llm_answer[0]
122
-
123
-
124
- def _multimodal_textbox_factory():
125
- return gr.MultimodalTextbox(
126
- value=None,
127
- interactive=True,
128
- sources="microphone",
129
- placeholder="Enter message...",
130
- show_label=False,
131
- autofocus=True,
132
- submit_btn="Send"
133
  )
 
134
 
135
 
136
  def user_message(content):
137
  return {"role": "user", "content": content}
138
 
139
 
140
- def add_message(history, message):
141
- for x in message["files"]:
142
- history.append(user_message({"path": x}))
143
- if message["text"]:
144
- history.append(user_message(message["text"]))
145
- return history, _multimodal_textbox_factory()
146
-
147
-
148
- def combine_model_inputs(msgs: list[dict[str, str]]) -> dict[str, list[str]]:
149
- messages = []
150
- files = []
151
- for msg in msgs:
152
- print(msg, messages, files)
153
- match msg:
154
- case {"content": (path,)}:
155
- messages.append({"role": msg["role"], "content": "<Audio><AudioHere></Audio> "})
156
- files.append(path)
157
- case _:
158
- messages.append(msg)
159
-
160
- # Join consecutive messages from the same role
161
- joined_messages = []
162
- for msg in messages:
163
- if joined_messages and joined_messages[-1]["role"] == msg["role"]:
164
- joined_messages[-1]["content"] += msg["content"]
165
- else:
166
- joined_messages.append(msg)
167
-
168
- return {"messages": joined_messages, "files": files}
169
 
170
-
171
- def bot_response(history: list):
172
- print(type(history))
173
- combined_inputs = combine_model_inputs(history)
174
- response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
175
- history.append({"role": "assistant", "content": response})
176
- return history
177
-
178
-
179
- def _chat_tab(examples):
180
- # Status indicator
181
- status_text = gr.Textbox(
182
- value=model_manager.get_status(),
183
- label="Model Status",
184
- interactive=False,
185
- visible=True
186
  )
187
-
188
- chatbot = gr.Chatbot(
189
- label="Chat",
190
- elem_id="chatbot",
191
- bubble_full_width=False,
192
- type="messages",
193
- render_markdown=False,
194
- resizeable=True
195
- )
196
-
197
- chat_input = _multimodal_textbox_factory()
198
- send_all = gr.Button("Send all", elem_id="send-all")
199
- clear_button = gr.ClearButton(components=[chatbot, chat_input], visible=False)
200
-
201
- chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
202
- bot_msg = send_all.click(
203
- bot_response,
204
- [chatbot],
205
- [chatbot],
206
- api_name="bot_response",
207
- )
208
-
209
- # Update status after bot response
210
- bot_msg.then(lambda: model_manager.get_status(), None, [status_text])
211
- bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
212
- clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])
213
-
214
- gr.Examples(
215
- list(examples.values()),
216
- chatbot,
217
- chatbot,
218
- example_labels=list(examples.keys()),
219
- examples_per_page=20,
220
- )
221
-
222
-
223
- def summarize_batch_results(results):
224
- summary = Counter(results)
225
- summary_str = "\n".join(f"{k}: {v}" for k, v in summary.most_common())
226
- return summary_str
227
-
228
-
229
- def run_batch_inference(files, task, progress=gr.Progress()) -> str:
230
- model = model_manager.load_model()
231
- if model is None:
232
- if model_manager.is_loading:
233
- return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
234
- elif model_manager.load_failed:
235
- return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again."
236
- else:
237
- return "Demo mode: Model not loaded. Please check the model configuration."
238
-
239
- outputs = []
240
- prompt = "<Audio><AudioHere></Audio> " + task
241
-
242
- for file in progress.tqdm(files):
243
- outputs.append(prompt_lm([file], [{"role": "user", "content": prompt}]))
244
-
245
- batch_summary: str = summarize_batch_results(outputs)
246
- report = f"Batch summary:\n{batch_summary}\n\n"
247
- return report
248
-
249
-
250
- def multi_extension_glob_mask(mask_base, *extensions):
251
- mask_ext = ["[{}]".format("".join(set(c))) for c in zip(*extensions)]
252
- if not mask_ext or len(set(len(e) for e in extensions)) > 1:
253
- mask_ext.append("*")
254
- return mask_base + "".join(mask_ext)
255
-
256
-
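For reference, a rough sketch of the glob mask the removed helper builds for the extensions used by `_batch_tab`; the character order inside each bracket can vary because Python sets are unordered.

```python
# Illustrative only: reproduces the removed helper's behaviour.
mask = multi_extension_glob_mask("**.", "mp3", "flac", "wav")
# zip(*extensions) pairs characters position-wise and stops at the shortest
# extension; the differing extension lengths then append the trailing "*".
# mask is roughly "**.[mfw][pla][3av]*"  (matches .mp3, .flac, .wav and similar)
```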
257
- def _batch_tab(file_selection: Literal["upload", "explorer"] = "upload"):
258
- if file_selection == "explorer":
259
- files = gr.FileExplorer(
260
- glob=multi_extension_glob_mask("**.", "mp3", "flac", "wav"),
261
- label="Select audio files",
262
- file_count="multiple",
263
  )
264
- elif file_selection == "upload":
265
- files = gr.Files(label="Uploaded files", file_types=["audio"], height=300)
266
- task = gr.Textbox(label="Task", placeholder="Enter task...", show_label=True)
267
-
268
- process_btn = gr.Button("Process")
269
- output = gr.TextArea()
270
-
271
- process_btn.click(
272
- run_batch_inference,
273
- [files, task],
274
- [output],
275
- )
276
-
277
-
278
- def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
279
- def get_line(row, start, end, annotation):
280
- return f"{row}\tSpectrogram 1\t1\t{start}\t{end}\t0\t8000\t{annotation}"
281
-
282
- raven_output = ["Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tAnnotation"]
283
- current_offset = 0
284
- last_label = ""
285
- row = 1
286
-
287
- for offset, label in sorted(outputs.items()):
288
- if label != last_label and last_label:
289
- raven_output.append(get_line(row, current_offset, offset, last_label))
290
- current_offset = offset
291
- row += 1
292
- if not last_label:
293
- current_offset = offset
294
- if label != "None":
295
- last_label = label
296
  else:
297
- last_label = ""
298
- if last_label:
299
- raven_output.append(get_line(row, current_offset, current_offset + chunk_len, last_label))
300
-
301
- return "\n".join(raven_output)
302
-
303
-
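For reference, a small hedged sketch of the Raven selection table the removed helper produced; the offsets and the single "Robin" label below are made up.

```python
# Hypothetical per-chunk labels keyed by start offset (seconds).
outputs = {0: "Robin", 5: "Robin", 10: "None"}
print(to_raven_format(outputs, chunk_len=10))
# Selection	View	Channel	Begin Time (s)	End Time (s)	Low Freq (Hz)	High Freq (Hz)	Annotation
# 1	Spectrogram 1	1	0	10	0	8000	Robin
```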
304
- def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int = 5, progress=gr.Progress()):
305
- # Check if model is loading
306
- if model_manager.is_loading:
307
- return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment.", None
308
-
309
- # Check if model failed to load
310
- if model_manager.load_failed:
311
- return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease refresh the page to try again.", None
312
-
313
- model = model_manager.load_model()
314
- if model is None:
315
- return "Demo mode: Model not loaded. Please check the model configuration.", None
316
-
317
- cuda_enabled = torch.cuda.is_available()
318
- outputs = {}
319
- offset = 0
320
-
321
- prompt = f"<Audio><AudioHere></Audio> {task}"
322
- prompt = model_manager.config.model.prompt_template.format(prompt)
323
-
324
- for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
325
- prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
326
- with torch.cuda.amp.autocast(dtype=torch.float16):
327
- llm_answers = model.generate(batch, model_manager.config.generate, prompts=prompt_strs)
328
- for answer in llm_answers:
329
- outputs[offset] = answer
330
- offset += hop_len
331
-
332
- report = f"Number of chunks: {len(outputs)}\n\n"
333
- for offset in sorted(outputs.keys()):
334
- report += f"{offset:02d}s:\t{outputs[offset]}\n"
335
-
336
- raven_output = to_raven_format(outputs, chunk_len=chunk_len)
337
- with tempfile.NamedTemporaryFile(mode="w", prefix="raven-", suffix=".txt", delete=False) as f:
338
- f.write(raven_output)
339
- raven_file = f.name
340
-
341
- return report, raven_file
342
-
343
-
344
- def _long_recording_tab():
345
- audio_input = gr.Audio(label="Upload audio file", type="filepath")
346
- task = gr.Dropdown(
347
- [
348
- "What are the common names for the species in the audio, if any?",
349
- "Caption the audio.",
350
- "Caption the audio, using the scientific name for any animal species.",
351
- "Caption the audio, using the common name for any animal species.",
352
- "What is the scientific name for the focal species in the audio?",
353
- "What is the common name for the focal species in the audio?",
354
- "What is the family of the focal species in the audio?",
355
- "What is the genus of the focal species in the audio?",
356
- "What is the taxonomic name of the focal species in the audio?",
357
- "What call types are heard from the focal species in the audio?",
358
- "What is the life stage of the focal species in the audio?",
359
- ],
360
- label="Tasks",
361
- allow_custom_value=True,
362
- )
363
- with gr.Accordion("Advanced options", open=False):
364
- hop_len = gr.Slider(1, 10, 5, label="Hop length (seconds)", step=1)
365
- chunk_len = gr.Slider(1, 10, 10, label="Chunk length (seconds)", step=1)
366
- process_btn = gr.Button("Process")
367
- output = gr.TextArea()
368
- download_raven = gr.DownloadButton("Download Raven file")
369
-
370
- process_btn.click(
371
- _run_long_recording_inference,
372
- [audio_input, task, chunk_len, hop_len],
373
- [output, download_raven],
374
- )
375
 
376
 
377
  def main(
378
  assets_dir: Path,
379
  cfg_path: str | Path,
380
  options: list[str] = [],
381
- device: str = "cuda",
382
  ):
383
  # Load configuration
384
  try:
@@ -394,7 +255,7 @@ def main(
394
  if not assets_dir.exists():
395
  print(f"Warning: Assets directory {assets_dir} does not exist")
396
  assets_dir.mkdir(exist_ok=True)
397
-
398
  # Create placeholder audio files if they don't exist
399
  laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
400
  frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
@@ -411,7 +272,9 @@ def main(
411
  "Caption the audio (Green Tree Frog)": [
412
  [
413
  user_message({"path": str(frog_audio)}),
414
- user_message("Caption the audio, using the common name for any animal species."),
 
 
415
  ]
416
  ],
417
  "Caption the audio (American Robin)": [
@@ -428,17 +291,31 @@ def main(
428
  ],
429
  }
430
 
431
- with gr.Blocks(title="NatureLM-audio", theme=gr.themes.Base(primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")])) as app:
 
432
  header = gr.HTML("""
433
  <div style="display: flex; align-items: center; gap: 12px;"><h2 style="margin: 0;">NatureLM-audio<span style="font-size: 0.55em; color: #28a745; background: #e6f4ea; padding: 2px 6px; border-radius: 4px; margin-left: 8px; display: inline-block; vertical-align: top;">BETA</span></h2></div>
434
 
435
  """)
436
-
437
  with gr.Tabs():
438
  with gr.Tab("Analyze Audio"):
439
- uploaded_audio = gr.State()
440
- with gr.Column(visible=True) as onboarding_message:
441
- gr.HTML("""
 
 
 
442
  <div style="
443
  background: transparent;
444
  border: 1px solid #e5e7eb;
@@ -476,45 +353,102 @@ def main(
476
  onmouseout="this.style.background='#3b82f6';"
477
  >View Tutorial</a>
478
  </div>
479
- """, padding=False)
 
 
480
  with gr.Column(visible=True) as upload_section:
481
- audio_input = gr.Audio(
482
  type="filepath",
483
- container=True,
484
- interactive=True,
485
- sources=['upload']
486
- )
487
  with gr.Group(visible=False) as chat:
488
- chatbot = gr.Chatbot(
489
- elem_id="chatbot",
490
- type="messages",
 
491
  render_markdown=False,
492
- feedback_options=["like", "dislike", "wrong species", "incorrect response", "other"],
493
- resizeable=True
 
 
 
 
 
 
494
  )
495
- chat_input = _multimodal_textbox_factory()
496
- send_all = gr.Button("Send all")
497
 
498
-
499
  def start_chat_interface(audio_path):
500
- return (
501
- gr.update(visible=False), # hide onboarding message
502
  gr.update(visible=True), # show upload section
503
- gr.update(visible=True), # show chat box
504
  )
505
 
506
  audio_input.change(
507
  fn=start_chat_interface,
508
  inputs=[audio_input],
509
- outputs=[onboarding_message, upload_section, chat]
510
  )
511
 
512
- chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
513
- send_all.click(bot_response, [chatbot], [chatbot])
 
 
 
514
 
 
 
 
515
 
516
  with gr.Tab("Sample Library"):
517
- gr.Markdown("## Sample Library\n\nExplore example audio files below.")
518
  gr.Examples(
519
  list(examples.values()),
520
  chatbot,
@@ -523,10 +457,10 @@ def main(
523
  examples_per_page=20,
524
  )
525
  with gr.Tab("💡 Help"):
526
- gr.Markdown("## User Guide") # to fill out
527
- gr.Markdown("## Share Feedback") # to fill out
528
- gr.Markdown("## FAQs") # to fill out
529
-
530
  app.css = """
531
  .welcome-banner {
532
  background: transparent !important;
@@ -550,7 +484,7 @@ def main(
550
  _batch_tab()
551
  with gr.Tab("Long Recording"):
552
  _long_recording_tab() """
553
-
554
  return app
555
 
556
 
@@ -559,8 +493,7 @@ app = main(
559
  assets_dir=Path("assets"),
560
  cfg_path=Path("configs/inference.yml"),
561
  options=[],
562
- device="cuda",
563
  )
564
 
565
  if __name__ == "__main__":
566
- app.launch()
 
1
+ import warnings
2
+ import numpy as np
 
3
  from pathlib import Path
4
+ from typing import Optional
5
+ from collections import Counter
6
 
7
  import gradio as gr
8
  import torch
9
+ import torchaudio
10
+ import matplotlib.pyplot as plt
11
 
12
  from NatureLM.config import Config
13
  from NatureLM.models.NatureLM import NatureLM
14
+ from NatureLM.infer import Pipeline
15
  import spaces
16
 
17
+ warnings.filterwarnings("ignore")
18
+ SAMPLE_RATE = 16000 # Default sample rate for NatureLM-audio
19
+
20
+
21
+ def get_spectrogram(audio: torch.Tensor) -> plt.Figure:
22
+ """Generate a spectrogram from the audio tensor."""
23
+ spectrogram = torchaudio.transforms.Spectrogram(n_fft=1024)(audio)
24
+ spectrogram = spectrogram.numpy()[0].squeeze()
25
+ # Convert to matplotlib figure with imshow
26
+ fig, ax = plt.subplots(figsize=(13, 5))
27
+ ax.imshow(np.log(spectrogram + 1e-3), aspect="auto", origin="lower", cmap="viridis")
28
+ ax.set_title("Spectrogram")
29
+ ax.set_xlabel("Time")
30
+ # Set x ticks to reflect 0 to audio duration seconds
31
+ if audio.dim() > 1:
32
+ duration = audio.size(1) / SAMPLE_RATE
33
+ else:
34
+ duration = audio.size(0) / SAMPLE_RATE
35
+
36
+ ax.set_xticks([0, spectrogram.shape[1]])
37
+ ax.set_xticklabels(["0s", f"{duration:.2f}s"])
38
+ ax.set_ylabel("Frequency")
39
+ # Set y ticks to reflect 0 to nyquist frequency (sample_rate/2)
40
+ nyquist_freq = SAMPLE_RATE / 2
41
+ ax.set_yticks(
42
+ [
43
+ 0,
44
+ spectrogram.shape[0] // 4,
45
+ spectrogram.shape[0] // 2,
46
+ 3 * spectrogram.shape[0] // 4,
47
+ spectrogram.shape[0] - 1,
48
+ ]
49
+ )
50
+ ax.set_yticklabels(
51
+ [
52
+ "0 Hz",
53
+ f"{nyquist_freq / 4:.0f} Hz",
54
+ f"{nyquist_freq / 2:.0f} Hz",
55
+ f"{3 * nyquist_freq / 4:.0f} Hz",
56
+ f"{nyquist_freq:.0f} Hz",
57
+ ]
58
+ )
59
+ fig.tight_layout()
60
+
61
+ return fig
62
+
63
 
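A minimal sketch of exercising the new `get_spectrogram` helper on one of the bundled clips (assumes torchaudio can decode the mp3; the resample step only matters when the source rate differs from SAMPLE_RATE):

```python
import torchaudio

# Sketch only: load a clip, bring it to the model sample rate, render the figure.
wav, sr = torchaudio.load("assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3")
if sr != SAMPLE_RATE:
    wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
fig = get_spectrogram(wav)
fig.savefig("spectrogram.png")
```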
64
  class ModelManager:
65
  """Manages model loading and state"""
66
+
67
  def __init__(self):
68
  self.model: Optional[NatureLM] = None
69
  self.config: Optional[Config] = None
70
  self.is_loaded = False
71
  self.is_loading = False
72
  self.load_failed = False
73
+
74
  def check_availability(self) -> tuple[bool, str]:
75
  """Check if the model is available for download"""
76
  try:
77
  from huggingface_hub import model_info
78
+
79
  info = model_info("EarthSpeciesProject/NatureLM-audio")
80
  return True, "Model is available"
81
  except Exception as e:
82
  return False, f"Model not available: {str(e)}"
83
+
84
  def reset_state(self):
85
  """Reset the model loading state to allow retrying after a failure"""
86
  self.model = None
 
88
  self.is_loading = False
89
  self.load_failed = False
90
  return self.get_status()
91
+
92
  def get_status(self) -> str:
93
  """Get the current model loading status"""
94
  if self.is_loaded:
 
99
  return "❌ Model failed to load. Please check the configuration."
100
  else:
101
  return "⏳ Ready to load model on first use"
102
+
103
  def load_model(self) -> Optional[NatureLM]:
104
  """Load the model if needed"""
105
  if self.is_loaded:
106
  return self.model
107
+
108
  if self.is_loading or self.load_failed:
109
  return None
110
+
111
  try:
112
  self.is_loading = True
113
  print("Loading model...")
114
+
115
  # Check if model is available first
116
  available, message = self.check_availability()
117
  if not available:
118
  raise Exception(f"Model not available: {message}")
119
+
120
  model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
121
+ model.to("cpu")
122
  model.eval()
123
+
124
+ pipe = Pipeline(model)
125
+ self.model = pipe
126
  self.is_loaded = True
127
  self.is_loading = False
128
  print("Model loaded successfully!")
129
+ return pipe
130
+
131
  except Exception as e:
132
  print(f"Error loading model: {e}")
133
  self.is_loading = False
 
138
  # Global model manager instance
139
  model_manager = ModelManager()
140
 
141
+
142
+ def take_majority_vote(results: list[list[dict]]) -> list[str]:
143
+ """For each audio file, take the majority vote of the labels across all windows"""
144
+ outputs = []
145
+ for result in results:
146
+ predictions = [window["prediction"] for window in result]
147
+ if not predictions:
148
+ continue
149
+ # Count occurrences of each label
150
+ counts = Counter(predictions)
151
+ # Find the most common label
152
+ most_common_label, _ = counts.most_common(1)[0]
153
+ outputs.append(most_common_label)
154
+
155
+ return outputs
156
+
157
+
158
  @spaces.GPU
159
+ def prompt_lm(
160
+ audios: list[str],
161
+ queries: list[str] | str,
162
+ window_length_seconds: float = 10.0,
163
+ hop_length_seconds: float = 10.0,
164
+ progress=gr.Progress(),
165
+ ) -> list[str]:
166
+ """Generate response using the model
167
+
168
+ Args:
169
+ audios (list[str]): List of audio file paths
170
+ queries (list[str] | str): Query or list of queries to process
171
+ window_length_seconds (float): Length of the window for processing audio
172
+ hop_length_seconds (float): Hop length for processing audio
173
+
174
+ Returns:
175
+ list[str]: List of generated responses for each audio-query pair
176
+ """
177
  model = model_manager.load_model()
178
+
179
  if model is None:
180
  if model_manager.is_loading:
181
  return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
 
183
  return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again using the retry button."
184
  else:
185
  return "Demo mode: Model not loaded. Please check the model configuration."
186
+
187
+ results: list[list[dict]] = model(
188
+ audios,
189
+ queries,
190
+ window_length_seconds=window_length_seconds,
191
+ hop_length_seconds=hop_length_seconds,
192
+ input_sample_rate=None,
193
+ progress_bar=progress,
 
 
 
 
 
194
  )
195
+ return results
196
 
197
 
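A hedged sketch of calling `prompt_lm` directly and collapsing its windowed predictions with `take_majority_vote`; the asset path and query come from this app, but the per-window dict shape with a `prediction` key is assumed from the Pipeline output.

```python
# Sketch only: mirrors how the chat handler below consumes the results.
results = prompt_lm(
    audios=["assets/nri-GreenTreeFrogEvergladesNP.mp3"],
    queries=["What is the common name for the focal species in the audio?"],
    window_length_seconds=10.0,
    hop_length_seconds=10.0,
)
if isinstance(results, str):
    print(results)  # model still loading, failed to load, or demo mode
else:
    print(take_majority_vote(results))  # one majority label per input file
```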
198
  def user_message(content):
199
  return {"role": "user", "content": content}
200
 
201
 
202
+ def add_message_and_get_response(
203
+ chatbot_history: list[dict], audio_input: str, chat_input: str
204
+ ) -> tuple[list[dict], str]:
205
+ """Add user message to chat and get model response"""
206
+ # Load audio with torchaudio and compute spectrogram
207
+ audio_tensor, sample_rate = torchaudio.load(audio_input)
208
+ duration = audio_tensor.size(1) / sample_rate
 
 
 
 
 
209
 
210
+ spectrogram_fig = get_spectrogram(audio_tensor)
211
+ # Add gr.Plot to chatbot history
212
+ chatbot_history.append(
213
+ {"role": "user", "content": gr.Plot(spectrogram_fig, label="Spectrogram")}
 
 
 
 
 
 
 
 
 
 
 
 
214
  )
215
+ # Get response
216
+ try:
217
+ response = prompt_lm(
218
+ audios=[audio_input],
219
+ queries=[chat_input],
220
+ window_length_seconds=duration,
221
+ hop_length_seconds=duration,
 
 
 
 
 
 
 
222
  )
223
+ # get first item
224
+ if isinstance(response, list) and len(response) > 0:
225
+ response = response[0][0]["prediction"]
 
 
 
 
 
 
226
  else:
227
+ response = "No response generated."
228
+ except Exception as e:
229
+ print(f"Error generating response: {e}")
230
+ response = "Error generating response. Please try again."
231
+
232
+ # Add user message to chat history
233
+ chatbot_history.append({"role": "user", "content": "Q: " + chat_input})
234
+ # Add model response to chat history
235
+ chatbot_history.append({"role": "assistant", "content": response})
236
+ return chatbot_history, ""
 
 
 
 
 
 
237
 
238
 
239
  def main(
240
  assets_dir: Path,
241
  cfg_path: str | Path,
242
  options: list[str] = [],
 
243
  ):
244
  # Load configuration
245
  try:
 
255
  if not assets_dir.exists():
256
  print(f"Warning: Assets directory {assets_dir} does not exist")
257
  assets_dir.mkdir(exist_ok=True)
258
+
259
  # Create placeholder audio files if they don't exist
260
  laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
261
  frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
 
272
  "Caption the audio (Green Tree Frog)": [
273
  [
274
  user_message({"path": str(frog_audio)}),
275
+ user_message(
276
+ "Caption the audio, using the common name for any animal species."
277
+ ),
278
  ]
279
  ],
280
  "Caption the audio (American Robin)": [
 
291
  ],
292
  }
293
 
294
+ with gr.Blocks(
295
+ title="NatureLM-audio",
296
+ theme=gr.themes.Base(
297
+ primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]
298
+ ),
299
+ ) as app:
300
  header = gr.HTML("""
301
  <div style="display: flex; align-items: center; gap: 12px;"><h2 style="margin: 0;">NatureLM-audio<span style="font-size: 0.55em; color: #28a745; background: #e6f4ea; padding: 2px 6px; border-radius: 4px; margin-left: 8px; display: inline-block; vertical-align: top;">BETA</span></h2></div>
302
 
303
  """)
304
+
305
  with gr.Tabs():
306
  with gr.Tab("Analyze Audio"):
307
+ uploaded_audio = gr.State()
308
+ # Status indicator
309
+ # status_text = gr.Textbox(
310
+ # value=model_manager.get_status(),
311
+ # label="Model Status",
312
+ # interactive=False,
313
+ # visible=True,
314
+ # )
315
+
316
+ with gr.Column(visible=True) as onboarding_message:
317
+ gr.HTML(
318
+ """
319
  <div style="
320
  background: transparent;
321
  border: 1px solid #e5e7eb;
 
353
  onmouseout="this.style.background='#3b82f6';"
354
  >View Tutorial</a>
355
  </div>
356
+ """,
357
+ padding=False,
358
+ )
359
  with gr.Column(visible=True) as upload_section:
360
+ audio_input = gr.Audio(
361
  type="filepath",
362
+ container=True,
363
+ interactive=True,
364
+ sources=["upload"],
365
+ )
366
  with gr.Group(visible=False) as chat:
367
+ chatbot = gr.Chatbot(
368
+ elem_id="chatbot",
369
+ type="messages",
370
+ label="Chat",
371
  render_markdown=False,
372
+ feedback_options=[
373
+ "like",
374
+ "dislike",
375
+ "wrong species",
376
+ "incorrect response",
377
+ "other",
378
+ ],
379
+ resizeable=True,
380
+ )
381
+ gr.Markdown("### Your Query")
382
+ task_dropdown = gr.Dropdown(
383
+ [
384
+ "What are the common names for the species in the audio, if any?",
385
+ "Caption the audio.",
386
+ "Caption the audio, using the scientific name for any animal species.",
387
+ "Caption the audio, using the common name for any animal species.",
388
+ "What is the scientific name for the focal species in the audio?",
389
+ "What is the common name for the focal species in the audio?",
390
+ "What is the family of the focal species in the audio?",
391
+ "What is the genus of the focal species in the audio?",
392
+ "What is the taxonomic name of the focal species in the audio?",
393
+ "What call types are heard from the focal species in the audio?",
394
+ "What is the life stage of the focal species in the audio?",
395
+ ],
396
+ label="Pre-configured Tasks",
397
+ allow_custom_value=True,
398
+ info="Select a task or enter a custom query below",
399
+ )
400
+ chat_input = gr.Textbox(
401
+ placeholder="e.g. 'Caption this audio'...",
402
+ type="text",
403
+ label="Query",
404
+ lines=2,
405
+ show_label=True,
406
+ container=False,
407
+ submit_btn="Send",
408
+ elem_id="chat-input",
409
+ )
410
+
411
+ # if task_dropdown is selected, set chat_input to that value
412
+ def set_query(task):
413
+ if task:
414
+ return gr.update(value=task)
415
+ return gr.update(value="")
416
+
417
+ task_dropdown.change(
418
+ fn=set_query,
419
+ inputs=[task_dropdown],
420
+ outputs=[chat_input],
421
+ )
422
+
423
+ clear_button = gr.ClearButton(
424
+ components=[chatbot, chat_input, audio_input], visible=False
425
  )
 
 
426
 
 
427
  def start_chat_interface(audio_path):
428
+ return (
429
+ gr.update(visible=False), # hide onboarding message
430
  gr.update(visible=True), # show upload section
431
+ gr.update(visible=True), # show chat box
432
  )
433
 
434
  audio_input.change(
435
  fn=start_chat_interface,
436
  inputs=[audio_input],
437
+ outputs=[onboarding_message, upload_section, chat],
438
  )
439
 
440
+ chat_input.submit(
441
+ add_message_and_get_response,
442
+ inputs=[chatbot, audio_input, chat_input],
443
+ outputs=[chatbot, chat_input],
444
+ ).then(lambda: gr.ClearButton(visible=True), None, [clear_button])
445
 
446
+ clear_button.click(
447
+ lambda: gr.ClearButton(visible=False), None, [clear_button]
448
+ )
449
 
450
  with gr.Tab("Sample Library"):
451
+ gr.Markdown("## Sample Library\n\nExplore example audio files below.")
452
  gr.Examples(
453
  list(examples.values()),
454
  chatbot,
 
457
  examples_per_page=20,
458
  )
459
  with gr.Tab("💡 Help"):
460
+ gr.Markdown("## User Guide") # to fill out
461
+ gr.Markdown("## Share Feedback") # to fill out
462
+ gr.Markdown("## FAQs") # to fill out
463
+
464
  app.css = """
465
  .welcome-banner {
466
  background: transparent !important;
 
484
  _batch_tab()
485
  with gr.Tab("Long Recording"):
486
  _long_recording_tab() """
487
+
488
  return app
489
 
490
 
 
493
  assets_dir=Path("assets"),
494
  cfg_path=Path("configs/inference.yml"),
495
  options=[],
 
496
  )
497
 
498
  if __name__ == "__main__":
499
+ app.launch()
requirements.txt CHANGED
@@ -1,31 +1,19 @@
1
- torch>=2.2.2
2
- torchaudio>=2.2.2
3
- torchvision>=0.17.2
4
- transformers[sentencepiece]>=4.44.2
5
- datasets>=2.20.0
6
- cloudpathlib[gs]>=0.20.0
7
- einops>=0.8.0
8
- gradio>=5.10.0
9
- google-cloud-aiplatform>=1.76.0
10
- Levenshtein>=0.25.1
11
- librosa>=0.9.2
12
- memoization>=0.4.0
13
- mir-eval>=0.7
14
- numpy>=1.26.4
15
- pandas>=1.4.3
16
- peft>=0.11.1
17
- plumbum>=1.7.2
18
- pydantic-settings>=2.7.1
19
- pydantic>=2.7.4
20
- pydub>=0.25.1
21
- pyyaml>=6.0
22
- resampy>=0.3.1
23
- scipy>=1.14.0
24
- soundfile>=0.12.1
25
- tensorboard>=2.18.0
26
- tensorboardX>=2.6.2.2
27
- spaces>=0.39.0
28
- tqdm>=4.66.4
29
- wandb>=0.17.3
30
- click>=8.1.7
31
- git+https://github.com/earthspecies/beans-zero.git
 
1
+ click>=8.2.1
2
+ einops>=0.8.1
3
+ gradio>=5.42.0
4
+ librosa>=0.11.0
5
+ pandas>=2.3.1
6
+ peft>=0.17.0
7
+ plumbum>=1.9.0
8
+ pydantic>=2.11.7
9
+ pydantic-settings>=2.10.1
10
+ pyyaml>=6.0.2
11
+ resampy>=0.4.3
12
+ scipy>=1.15.3
13
+ soundfile>=0.13.1
14
+ spaces>=0.40.0
15
+ torch>=2.8.0
16
+ torchaudio>=2.8.0
17
+ tqdm>=4.67.1
18
+ transformers[sentencepiece]>=4.55.2
19
+ matplotlib>=3.10.5