
Write music scores with llama

Try the model online: https://huggingface.co/spaces/dx2102/llama-midi

This model is finetuned from the Llama-3.2-1B language model.

It learns to write MIDI music scores in a plain-text representation: one note per line, in the form 'pitch duration wait velocity instrument', with times given in milliseconds.

Optionally, a score title can be included at the start of the prompt.

To use this model, you can take existing Llama text-generation code and simply replace meta-llama/Llama-3.2-1B with dx2102/llama-midi.

import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation", 
    model="dx2102/llama-midi", 
    torch_dtype=torch.bfloat16, 
    device="cuda", # cuda/mps/cpu
)

txt = pipe(
'''
Bach
pitch duration wait velocity instrument
'''.strip(),
    max_new_tokens=10,  # small for a quick test; increase for longer pieces
    temperature=1.0,
    top_p=1.0,
)[0]['generated_text']
print(txt)
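
Since this is an ordinary Llama-3.2-1B finetune, the lower-level transformers API works as well. The sketch below is an optional alternative to the pipeline call above; it keeps the same sampling settings and uses the standard TextStreamer utility so note lines are printed as they are generated:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("dx2102/llama-midi")
model = AutoModelForCausalLM.from_pretrained(
    "dx2102/llama-midi",
    torch_dtype=torch.bfloat16,
    device_map="cuda",  # cuda/mps/cpu
)

prompt = '''
Bach
pitch duration wait velocity instrument
'''.strip()

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(
    **inputs,
    max_new_tokens=500,
    do_sample=True,
    temperature=1.0,
    top_p=1.0,
    streamer=TextStreamer(tokenizer),  # stream tokens to stdout as they are produced
)
txt = tokenizer.decode(output[0], skip_special_tokens=True)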

To convert the text representation back to a midi file, try this:

# install this midi library
pip install symusic

symusic is a Python library with a fast C++ backend for reading, editing and writing MIDI files.
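
Before converting model output, here is a minimal sanity check of the library itself (a rough sketch; example.mid is a placeholder for any MIDI file you have on disk):

import symusic

# Load a MIDI file with times measured in seconds (ttype='Second'),
# the same convention used by the conversion code below.
score = symusic.Score('example.mid', ttype='Second')
for track in score.tracks:
    print(track.program, track.is_drum, len(track.notes))

# Write it back out unchanged.
score.dump_midi('copy.mid')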

import symusic

# For example
txt = '''pitch duration wait velocity instrument

71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''

def postprocess(txt, path):
    # keep only the note lines: drop everything before the last blank line (title / header prompt)
    txt = txt.split('\n\n')[-1]

    tracks = {}

    now = 0
    # ignore any malformed lines the model may produce
    try:
        for line in txt.split('\n'):
            pitch, duration, wait, velocity, instrument = line.split()
            pitch, duration, wait, velocity = [int(x) for x in [pitch, duration, wait, velocity]]
            if instrument not in tracks:
                tracks[instrument] = symusic.core.TrackSecond()
                if instrument != 'drum':
                    tracks[instrument].program = int(instrument)
                else:
                    tracks[instrument].is_drum = True
            # e.g. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Second')
            tracks[instrument].notes.append(symusic.core.NoteSecond(
                time=now/1000,
                duration=duration/1000,
                pitch=pitch,
                velocity=velocity * 4,  # undo the velocity downscaling from preprocess
            ))
            now += wait
    except Exception as e:
        print('Postprocess: Ignored error:', e)
    
    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')

    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)

postprocess(txt, './result.mid')
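
Putting the two steps together, the pipeline output from the generation example can be written straight to a MIDI file. A short sketch (it reuses pipe and postprocess from above, with a larger max_new_tokens so the model has room to write a full passage; the output filename is arbitrary):

prompt = '''
Bach
pitch duration wait velocity instrument
'''.strip()

txt = pipe(prompt, max_new_tokens=1000, temperature=1.0, top_p=1.0)[0]['generated_text']
postprocess(txt, './bach.mid')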

Similarly, to convert a midi file to the text representation:

def preprocess(path):
    # turn the midi file into the text representation and return it as a string
    # midi files in the wild may be broken, so loading errors are caught
    try:
        score = symusic.Score(path, ttype='Second')
    except Exception as e:
        print('Ignored midi loading error:', e)
        return ''
    
    # prolong notes held by a sustain pedal to the end of that pedal
    score = score.copy()
    for track in score.tracks:
        notes = track.notes
        pedals = track.pedals
        track.pedals = []  # drop pedal events; their effect is baked into the note durations below
        j = 0
        for i, note in enumerate(notes):
            while j < len(pedals) and pedals[j].time + pedals[j].duration < note.time:
                j += 1
            if j < len(pedals) and pedals[j].time <= note.time <= pedals[j].time + pedals[j].duration:
                # adjust the duration
                note.duration = max(
                    note.duration, 
                    pedals[j].time + pedals[j].duration - note.time,
                )
    
    notes = []
    for track in score.tracks:
        instrument = str(track.program)   # program id. `instrument` is always a string.
        if track.is_drum:
            instrument = 'drum'
        for note in track.notes:
            notes.append((note.time, note.duration, note.pitch, note.velocity, instrument))
    # dedup: drop duplicate notes that share the same start time, duration and pitch
    notes = list({
        (time, duration, pitch): (time, duration, pitch, velocity, instrument) 
        for time, duration, pitch, velocity, instrument in notes
    }.values())
    # merge channels, sort by start time. If notes start at the same time, the higher pitch comes first.
    notes.sort(key=lambda x: (x[0], -x[2]))
    # Translate absolute start times to the delta-time format used by the model:
    # each line becomes 'pitch duration wait velocity instrument', with times in milliseconds.
    # 'wait' is the gap until the next note starts (0 for the last note).
    notes1 = []
    for i, (time, duration, pitch, velocity, instrument) in enumerate(notes):
        wait = notes[i + 1][0] - time if i + 1 < len(notes) else 0
        notes1.append((
            pitch,
            round(duration * 1000),
            round(wait * 1000),
            velocity // 4,  # postprocess() multiplies velocity back by 4
            instrument,
        ))

    txt = []
    for note in notes1:
        txt.append(' '.join(map(str, note)))
    txt = '\n'.join(txt)
    return txt

txt = preprocess('./test.mid')
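
Because the model simply continues the text representation, preprocess can also be used to ask it to extend an existing piece. A rough sketch (it assumes ./test.mid exists, that the prompt layout of header line, blank line, then notes matches the format shown earlier, and uses an arbitrary cutoff of 100 note lines):

# Use the opening of an existing piece as the prompt and let the model continue it.
opening = '\n'.join(preprocess('./test.mid').split('\n')[:100])
prompt = 'pitch duration wait velocity instrument\n\n' + opening

continuation = pipe(prompt, max_new_tokens=500, temperature=1.0, top_p=1.0)[0]['generated_text']
postprocess(continuation, './continued.mid')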