Automatic Speech Recognition
Transformers
Safetensors
English
asr_model
feature-extraction
asr
speech-recognition
audio
qwen
glm-asr
custom_code
Instructions to use mazesmazes/tiny-audio with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mazesmazes/tiny-audio with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("mazesmazes/tiny-audio", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| """Audio projector modules for bridging encoder and decoder embeddings. | |
| This module contains all projector architectures: | |
| - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling | |
| - MOSAProjector: MOSA-style dense mixture of experts | |
| - SharedMoEAudioProjector: Shared expert + sparse routed experts | |
| - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style) | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F # noqa: N812 | |
| from transformers import AutoModel, Blip2QFormerConfig | |
| from transformers.models.llama.modeling_llama import LlamaRMSNorm | |
| # ============================================================================= | |
| # MLP Projector | |
| # ============================================================================= | |
| class MLPAudioProjector(nn.Module): | |
| """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR).""" | |
| def __init__(self, config): | |
| """Initialize MLP projector. | |
| Args: | |
| config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride | |
| """ | |
| super().__init__() | |
| encoder_dim = getattr(config, "encoder_dim", 768) | |
| llm_dim = getattr(config, "llm_dim", 2048) | |
| self.k = getattr(config, "projector_pool_stride", 4) | |
| # Frame stacking: concat k adjacent frames then project | |
| in_dim = encoder_dim * self.k | |
| # Hidden dim defaults to llm_dim, can be overridden via config | |
| hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim | |
| self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False) | |
| self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6) | |
| self.act = nn.GELU() | |
| self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False) | |
| # Output norm aligns the projector's RMS with the LM's embed_tokens | |
| # distribution. Without it, linear_2's Kaiming-uniform init produces | |
| # outputs ~30× quieter than embed rows, which saturates softmax at | |
| # audio positions and starves them of gradient. | |
| self.norm_2 = LlamaRMSNorm(llm_dim, eps=1e-6) | |
| def get_output_length(self, input_length: int) -> int: | |
| """Calculate output sequence length given input length (matches GLM-ASR).""" | |
| # GLM-ASR formula: (L - merge_factor) // merge_factor + 1 | |
| return (input_length - self.k) // self.k + 1 | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """Project audio features to LLM embedding space. | |
| Args: | |
| x: Audio encoder output of shape [batch, seq_len, encoder_dim] | |
| Returns: | |
| Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim] | |
| """ | |
| x = _frame_stack(x, self.k) | |
| x = self.linear_1(x) | |
| x = self.norm(x) | |
| x = self.act(x) | |
| x = self.linear_2(x) | |
| return self.norm_2(x) | |
| # ============================================================================= | |
| # MoE Projector (MOSA-style) | |
| # ============================================================================= | |
| def _frame_stack(x: torch.Tensor, k: int) -> torch.Tensor: | |
| """Stack k adjacent frames along the feature dim. | |
| Truncates trailing frames that don't fill a complete k-frame window, | |
| matching GLM-ASR's `(seq_len - k) // k + 1` formula. | |
| """ | |
| batch, seq, dim = x.shape | |
| out_len = (seq - k) // k + 1 | |
| return x[:, : out_len * k, :].reshape(batch, out_len, dim * k) | |
| class SimpleAdapter(nn.Module): | |
| """Simple 2-layer GELU adapter (from MOSA paper).""" | |
| def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): | |
| super().__init__() | |
| self.fc1 = nn.Linear(input_dim, hidden_dim) | |
| self.act = nn.GELU() | |
| self.fc2 = nn.Linear(hidden_dim, output_dim) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.fc2(self.act(self.fc1(x))) | |
| class MOSAProjector(nn.Module): | |
| """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters. | |
| Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998). | |
| Uses softmax gating over all experts (dense MoE) with only cross-entropy loss. | |
| Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total). | |
| """ | |
| ADAPTER_HIDDEN_DIM = 4096 | |
| ROUTER_HIDDEN_DIM = 512 | |
| CONV_KERNEL = 3 | |
| CONV_STRIDE = 2 | |
| CONV_PADDING = 1 | |
| def __init__(self, config): | |
| """Initialize MOSA projector. | |
| Args: | |
| config: ASRConfig with encoder_dim, llm_dim, num_experts | |
| """ | |
| super().__init__() | |
| self.encoder_dim = getattr(config, "encoder_dim", None) or 1280 | |
| self.llm_dim = getattr(config, "llm_dim", None) or 2048 | |
| self.num_experts = getattr(config, "num_experts", None) or 4 # MOSA-Base uses 4 | |
| conv_kwargs = { | |
| "kernel_size": self.CONV_KERNEL, | |
| "stride": self.CONV_STRIDE, | |
| "padding": self.CONV_PADDING, | |
| } | |
| self.downsampler = nn.Sequential( | |
| nn.Conv1d(self.encoder_dim, self.encoder_dim, **conv_kwargs), | |
| nn.GELU(), | |
| nn.Conv1d(self.encoder_dim, self.llm_dim, **conv_kwargs), | |
| nn.GELU(), | |
| ) | |
| self.router = nn.Sequential( | |
| nn.Linear(self.llm_dim, self.ROUTER_HIDDEN_DIM), | |
| nn.ReLU(), | |
| nn.Linear(self.ROUTER_HIDDEN_DIM, self.num_experts), | |
| ) | |
| self.experts = nn.ModuleList( | |
| [ | |
| SimpleAdapter(self.llm_dim, self.ADAPTER_HIDDEN_DIM, self.llm_dim) | |
| for _ in range(self.num_experts) | |
| ] | |
| ) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """Project audio features using mixture of experts. | |
| Args: | |
| x: Audio encoder output of shape [batch, seq_len, encoder_dim] | |
| Returns: | |
| Projected features of shape [batch, out_len, llm_dim] | |
| """ | |
| x = self.downsampler(x.transpose(1, 2)).transpose(1, 2) | |
| routing_weights = F.softmax(self.router(x), dim=-1) # (B, out_len, num_experts) | |
| # Accumulate weighted expert outputs without materializing all experts at once. | |
| output = self.experts[0](x) * routing_weights[..., 0:1] | |
| for i, expert in enumerate(self.experts[1:], start=1): | |
| output = output + expert(x) * routing_weights[..., i : i + 1] | |
| return output | |
| def get_output_length(self, input_length: int) -> int: | |
| """Calculate output sequence length after Conv1d downsampling (4x reduction).""" | |
| length = input_length | |
| for _ in range(2): | |
| length = (length + 2 * self.CONV_PADDING - self.CONV_KERNEL) // self.CONV_STRIDE + 1 | |
| return length | |
| # ============================================================================= | |
| # MoE Projector (Pure PyTorch with Shared Expert) | |
| # ============================================================================= | |
| class MoEAudioProjector(nn.Module): | |
| """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation. | |
| Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens. | |
| No external dependencies (megablocks removed). | |
| Architecture matches main branch: norm → experts(in_dim → hidden → out_dim) | |
| """ | |
| def __init__(self, config): | |
| """Initialize MoE projector. | |
| Args: | |
| config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok | |
| """ | |
| super().__init__() | |
| self.k = getattr(config, "projector_pool_stride", 4) | |
| self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01) | |
| # Stability coefficients | |
| self.router_z_loss_coef = getattr( | |
| config, "router_z_loss_coef", 1e-4 | |
| ) # Prevents logit explosion | |
| self.router_jitter_noise = getattr( | |
| config, "router_jitter_noise", 0.01 | |
| ) # Prevents expert collapse | |
| in_dim = config.encoder_dim * self.k | |
| out_dim = config.llm_dim | |
| # Expert hidden dim (default = output dim) | |
| hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim | |
| # Number of experts and top-k selection | |
| self.num_experts = getattr(config, "num_experts", 4) | |
| self.top_k = getattr(config, "num_experts_per_tok", 2) | |
| # A. Normalize stacked input (like main branch SharedMoEBlock) | |
| self.norm = LlamaRMSNorm(in_dim, eps=1e-6) | |
| # B. Router (operates on stacked input) | |
| self.router = nn.Linear(in_dim, self.num_experts, bias=False) | |
| # C. Experts: simple 2-layer MLP (same as MLPAudioProjector) | |
| self.experts = nn.ModuleList( | |
| [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)] | |
| ) | |
| # D. Shared Expert (same architecture) | |
| self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim) | |
| # E. Initialize weights for stable training | |
| self._init_weights() | |
| self.last_aux_loss = torch.tensor(0.0) | |
| def _init_weights(self): | |
| """Initialize weights for stable training start.""" | |
| with torch.no_grad(): | |
| # Router: small weights -> uniform probability | |
| nn.init.normal_(self.router.weight, mean=0.0, std=0.02) | |
| # Experts: xavier for fc1, small for fc2 (output) | |
| for expert in [self.shared_expert, *self.experts]: | |
| nn.init.xavier_uniform_(expert.fc1.weight) | |
| nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01) # Small init | |
| def get_output_length(self, input_length: int) -> int: | |
| """Calculate output sequence length given input length (matches MLP projector).""" | |
| return (input_length - self.k) // self.k + 1 | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """Project audio features using shared + sparse MoE. | |
| Args: | |
| x: Audio encoder output of shape [batch, seq_len, encoder_dim] | |
| Returns: | |
| Projected features of shape [batch, out_len, llm_dim] | |
| """ | |
| x = _frame_stack(x, self.k) | |
| batch, out_len, _ = x.shape | |
| # Normalize stacked input (like main branch SharedMoEBlock) | |
| x = self.norm(x) | |
| flat_x = x.view(-1, x.size(-1)) # [tokens, in_dim] | |
| # 3. Shared Expert (compute first, creates output tensor) | |
| output = self.shared_expert(flat_x) | |
| # 4. Sparse Experts (in-place add to shared output) | |
| self.last_aux_loss = self._forward_sparse(flat_x, output) | |
| return output.view(batch, out_len, -1) | |
| def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor: | |
| """Stability-hardened sparse expert dispatch (in-place add to output). | |
| Args: | |
| x: Flattened input of shape [tokens, dim] | |
| output: Output tensor to add sparse expert results into (in-place) | |
| Returns: | |
| Auxiliary loss tensor | |
| """ | |
| # A. Router Logic with Jitter | |
| logits = self.router(x) | |
| if self.training and self.router_jitter_noise > 0: | |
| # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary | |
| # Prevents router from getting stuck on one expert early in training | |
| noise = torch.empty_like(logits).uniform_( | |
| 1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise | |
| ) | |
| logits = logits * noise | |
| # Force float32 for softmax (bf16/fp16 exponentials can overflow) | |
| probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x) | |
| # B. Top-K Selection | |
| top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1) | |
| # Normalize weights so they sum to 1.0 | |
| top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6) | |
| # C. Aux Loss + Z-Loss | |
| aux_loss = torch.tensor(0.0, device=x.device) | |
| if self.training: | |
| # Load balancing loss (batch-size invariant) | |
| prob_per_expert = probs.mean(0) # [num_experts] | |
| target = 1.0 / self.num_experts | |
| balance_loss = ( | |
| self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts | |
| ) | |
| # Z-loss: penalty on large logits to prevent softmax saturation | |
| z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean() | |
| aux_loss = balance_loss + z_loss | |
| # D. Dispatch Loop (in-place add to output) | |
| for i, expert in enumerate(self.experts): | |
| # Create boolean mask for tokens that selected Expert 'i' | |
| mask = top_k_indices == i | |
| if mask.any(): | |
| # token_idx = which tokens, k_idx = 1st or 2nd choice | |
| token_idx, k_idx = torch.where(mask) | |
| # Gather inputs and compute | |
| expert_input = x[token_idx] | |
| expert_output = expert(expert_input) | |
| # Apply routing weight | |
| weight = top_k_weights[token_idx, k_idx].unsqueeze(-1) | |
| weighted_output = (expert_output * weight).type_as(output) | |
| # Scatter back in-place (index_add_ is atomic and deterministic) | |
| output.index_add_(0, token_idx, weighted_output) | |
| return aux_loss | |
| def get_aux_loss(self) -> torch.Tensor: | |
| """Return auxiliary load balancing loss.""" | |
| return self.last_aux_loss | |
| # ============================================================================= | |
| # QFormer Projector (Granite-style) | |
| # ============================================================================= | |
| class QFormerAudioProjector(nn.Module): | |
| """ | |
| BLIP-2 QFormer projector with learnable queries. | |
| Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable | |
| query embeddings to compress and project audio encoder outputs. The audio | |
| sequence is processed in windows and downsampled via cross-attention. | |
| """ | |
| def __init__(self, config): | |
| """Initialize QFormer projector. | |
| Args: | |
| config: ASRConfig with encoder_dim, llm_dim, qformer_* settings | |
| """ | |
| super().__init__() | |
| encoder_dim = config.encoder_dim | |
| llm_dim = config.llm_dim | |
| # Window and downsampling parameters (Granite defaults: window=15, downsample=5) | |
| self.window_size = getattr(config, "qformer_window_size", 15) | |
| self.downsample_rate = getattr(config, "downsample_rate", 5) | |
| self.num_queries = self.window_size // self.downsample_rate | |
| # QFormer hidden size (matches encoder for cross-attention) | |
| qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim | |
| qformer_num_layers = getattr(config, "qformer_num_layers", 2) | |
| qformer_num_heads = getattr(config, "qformer_num_heads", 16) | |
| qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or ( | |
| qformer_hidden * 4 | |
| ) | |
| # Learnable query embeddings (Granite uses std=1.0) | |
| self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden)) | |
| self.query.data.normal_(mean=0.0, std=1.0) | |
| # Optional projection if encoder dim != qformer hidden | |
| if encoder_dim != qformer_hidden: | |
| self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False) | |
| else: | |
| self.encoder_proj = None | |
| # Configure QFormer to match Granite's exact config | |
| qformer_config = Blip2QFormerConfig( | |
| hidden_size=qformer_hidden, | |
| num_hidden_layers=qformer_num_layers, | |
| num_attention_heads=qformer_num_heads, | |
| intermediate_size=qformer_intermediate, | |
| encoder_hidden_size=qformer_hidden, | |
| cross_attention_frequency=1, | |
| # Granite-specific settings | |
| hidden_act="gelu", | |
| attention_probs_dropout_prob=0.1, | |
| hidden_dropout_prob=0.1, | |
| layer_norm_eps=1e-12, | |
| initializer_range=0.02, | |
| ) | |
| self.qformer = AutoModel.from_config(qformer_config) | |
| # Final projection to LLM dimension (Granite uses bias=True) | |
| self.linear = nn.Linear(qformer_hidden, llm_dim) | |
| def get_output_length(self, input_length): | |
| """Calculate output sequence length given input length. | |
| Accepts either Python ints or torch tensors; uses ceiling division so | |
| the formula is identical for both — math.ceil would block tensors. | |
| """ | |
| nblocks = (input_length + self.window_size - 1) // self.window_size | |
| return nblocks * self.num_queries | |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Args: | |
| hidden_states: [batch_size, seq_len, encoder_dim] | |
| Returns: | |
| projected: [batch_size, num_output_tokens, llm_dim] | |
| """ | |
| batch_size, seq_len, dim = hidden_states.size() | |
| # Ensure float dtype for QFormer | |
| target_dtype = self.query.dtype | |
| if hidden_states.dtype != target_dtype: | |
| hidden_states = hidden_states.to(target_dtype) | |
| # Optional encoder projection | |
| if self.encoder_proj is not None: | |
| hidden_states = self.encoder_proj(hidden_states) | |
| # Compute number of windows and pad to fit | |
| nblocks = math.ceil(seq_len / self.window_size) | |
| pad = nblocks * self.window_size - seq_len | |
| if pad > 0: | |
| hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0) | |
| # Reshape to process each window: [batch*nblocks, window_size, dim] | |
| effective_batch = batch_size * nblocks | |
| hidden_states = hidden_states.view(effective_batch, self.window_size, -1) | |
| # Expand queries to match batch size | |
| query_embeds = self.query.expand(effective_batch, -1, -1) | |
| # QFormer cross-attention | |
| query_output = self.qformer( | |
| query_embeds=query_embeds, | |
| encoder_hidden_states=hidden_states, | |
| return_dict=True, | |
| ) | |
| # Reshape back: [batch, nblocks * num_queries, hidden] | |
| output_tokens = nblocks * self.num_queries | |
| query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1) | |
| # Project to LLM dimension | |
| return self.linear(query_proj) | |
| # ============================================================================= | |
| # Projector Registry | |
| # ============================================================================= | |
| PROJECTOR_CLASSES = { | |
| "mlp": MLPAudioProjector, | |
| "mosa": MOSAProjector, | |
| "moe": MoEAudioProjector, | |
| "qformer": QFormerAudioProjector, | |
| } | |