# https://huggingface.co/spaces/jujutechnology/ebook2audiobook
# src/chatterbox/models/tokenizers/tokenizer.py
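
"""English text tokenizer for Chatterbox: a thin wrapper around a
Hugging Face `tokenizers.Tokenizer`, mapping spaces to a [SPACE] token."""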

import logging

import torch
from tokenizers import Tokenizer

# Special tokens
SOT = "[START]"
EOT = "[STOP]"
UNK = "[UNK]"
SPACE = "[SPACE]"
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]

logger = logging.getLogger(__name__)

class EnTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        # Sanity-check that the loaded vocab defines the start/stop tokens.
        voc = self.tokenizer.get_vocab()
        assert SOT in voc, f"Vocab is missing the start token {SOT}"
        assert EOT in voc, f"Vocab is missing the stop token {EOT}"

    def text_to_tokens(self, text: str):
        # Encode and add a batch dimension: (seq_len,) -> (1, seq_len).
        text_tokens = self.encode(text)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str, verbose=False):
        """
        Replace literal spaces with the SPACE token, then encode the
        text with the underlying Tokenizer. (Any text cleaning or
        `lang_id` prefixing is expected to happen upstream.)
        """
        txt = txt.replace(' ', SPACE)
        code = self.tokenizer.encode(txt)
        ids = code.ids
        return ids

    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()

        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
        # The decoder joins tokens with spaces, so drop those separators,
        # restore real spaces from SPACE markers, and strip stop/unknown tokens.
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')
        txt = txt.replace(UNK, '')
        return txt
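
# Minimal usage sketch, assuming a trained `tokenizers` vocab file at
# "tokenizer.json" (hypothetical path) that defines the [START]/[STOP]
# special tokens required by check_vocabset_sot_eot().
if __name__ == "__main__":
    tok = EnTokenizer("tokenizer.json")

    ids = tok.encode("hello world")        # list[int]; spaces become [SPACE]
    batch = tok.text_to_tokens("hello world")
    print(batch.shape)                     # torch.Size([1, seq_len])

    # decode() undoes the [SPACE] substitution and strips [STOP]/[UNK],
    # so a clean round-trip should return the original string.
    print(tok.decode(ids))                 # expected: "hello world"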