{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JASCO\n", "Welcome to JASCO's demo jupyter notebook. \n", "Here you will find a self-contained example of how to use JASCO for temporally controlled music generation.\n", "\n", "You can choose a model from the following selection:\n", "1. facebook/jasco-chords-drums-400M - 10s music generation conditioned on text, chords and drums, 400M parameters\n", "2. facebook/jasco-chords-drums-1B - 10s music generation conditioned on text, chords and drums, 1B parameters\n", "3. facebook/jasco-chords-drums-melody-400M - 10s music generation conditioned on text, chords, drums and melody, 400M parameters\n", "4. facebook/jasco-chords-drums-melody-1B - 10s music generation conditioned on text, chords, drums and melody, 1B parameters\n", "\n", "First, we start by initializing the JASCO model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os \n", "from audiocraft.models import JASCO\n", "\n", "model = JASCO.get_pretrained('facebook/jasco-chords-drums-melody-400M', chords_mapping_path='../assets/chord_to_index_mapping.pkl')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let us configure the generation parameters. Specifically, you can control the following:\n", "* `cfg_coef_all` (float, optional): Coefficient used for classifier free guidance - fully conditional term. \n", " Defaults to 5.0.\n", "* `cfg_coef_txt` (float, optional): Coefficient used for classifier free guidance - additional text conditional term. \n", " Defaults to 0.0.\n", "\n", "When left unchanged, JASCO will revert to its default parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.set_generation_params(\n", " cfg_coef_all=0.0,\n", " cfg_coef_txt=5.0\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can go ahead and start generating music given textual prompts." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Text-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from audiocraft.utils.notebook import display_audio\n", "\n", "# set textual prompt\n", "text = \"Funky groove with electric piano playing blue chords rhythmically\"\n", "\n", "# run the model\n", "print(\"Generating...\") \n", "output = model.generate(descriptions=[text], progress=True)\n", "\n", "# display the result\n", "print(f\"Text: {text}\\n\")\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can start adding temporal controls! 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we can start adding temporal controls!\n", "We begin with conditioning on chord progressions; since a temporal condition is now present, we first re-balance the guidance coefficients:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Chords-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.set_generation_params(\n", "    cfg_coef_all=1.5,\n", "    cfg_coef_txt=3.0\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from audiocraft.utils.notebook import display_audio\n", "\n", "# set textual prompt\n", "text = \"Strings, woodwind, orchestral, symphony.\"\n", "\n", "# define chord progression as (chord label, onset time in seconds) tuples\n", "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n", "\n", "# run the model\n", "print(\"Generating...\")\n", "output = model.generate_music(descriptions=[text], chords=chords, progress=True)\n", "\n", "# display the result\n", "print(f'Text: {text}')\n", "print(f'Chord progression: {chords}')\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can condition the generation on drum tracks:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Drums-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchaudio\n", "from audiocraft.utils.notebook import display_audio\n", "\n", "# load drum prompt\n", "drums_waveform, sr = torchaudio.load(\"../assets/sep_drums_1.mp3\")\n", "\n", "# set textual prompt\n", "text = \"distortion guitars, heavy rock, catchy beat\"\n", "\n", "# run the model\n", "print(\"Generating...\")\n", "output = model.generate_music(\n", "    descriptions=[text],\n", "    drums_wav=drums_waveform,\n", "    drums_sample_rate=sr,\n", "    progress=True\n", ")\n", "\n", "# display the result\n", "print('drum prompt:')\n", "display_audio(drums_waveform, sample_rate=sr)\n", "print(f'Text: {text}')\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] },
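{ "cell_type": "markdown", "metadata": {}, "source": [ "The drum prompt does not have to come from the bundled assets; any percussive recording should work. Below is a minimal, hypothetical preprocessing sketch (the path `my_drums.wav` is a placeholder): we downmix to mono and keep the first 10 seconds to match the model's generation length." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchaudio\n", "\n", "# hypothetical user-provided recording; replace the path with your own file\n", "my_drums, my_sr = torchaudio.load('my_drums.wav')\n", "\n", "# downmix to mono and keep the first 10 seconds to match JASCO's 10s clips\n", "my_drums = my_drums.mean(dim=0, keepdim=True)[..., :10 * my_sr]" ] },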
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can also combine multiple temporal controls!\n", "Let's move on to generating with both chords and drums conditioning:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Drums + Chords conditioning" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchaudio\n", "from audiocraft.utils.notebook import display_audio\n", "\n", "# load drum prompt\n", "drums_waveform, sr = torchaudio.load(\"../assets/sep_drums_1.mp3\")\n", "\n", "# set textual prompt\n", "text = \"string quartet, orchestral, dramatic\"\n", "\n", "# define chord progression as (chord label, onset time in seconds) tuples\n", "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n", "\n", "# run the model\n", "print(\"Generating...\")\n", "output = model.generate_music(\n", "    descriptions=[text],\n", "    drums_wav=drums_waveform,\n", "    drums_sample_rate=sr,\n", "    chords=chords,\n", "    progress=True\n", ")\n", "\n", "# display the result\n", "print('drum prompt:')\n", "display_audio(drums_waveform, sample_rate=sr)\n", "print(f'Chord progression: {chords}')\n", "print(f'Text: {text}')\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Melody + Drums + Chords conditioning - inference example" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "from demucs import pretrained\n", "from demucs.apply import apply_model\n", "from demucs.audio import convert_audio\n", "import torch\n", "import torchaudio\n", "from audiocraft.utils.notebook import display_audio\n", "import matplotlib.pyplot as plt\n", "\n", "# --------------------------\n", "# First, choose file to load\n", "# --------------------------\n", "fnames = ['salience_1', 'salience_2']\n", "chords = [\n", "    [('N', 0.0), ('Eb7', 1.088), ('C#', 4.352), ('D', 4.864), ('Dm7', 6.72), ('G7', 8.256), ('Am7b5/G', 9.152)],  # for salience 1\n", "    [('N', 0.0), ('C', 0.32), ('Dm7', 3.456), ('Am', 4.608), ('F', 8.32), ('C', 9.216)]  # for salience 2\n", "]\n", "file_idx = 0  # either 0 or 1\n", "\n", "\n", "# ------------------------------------\n", "# display audio, melody map and chords\n", "# ------------------------------------\n", "def plot_chromagram(tensor):\n", "    # check that the input is a PyTorch tensor\n", "    if not torch.is_tensor(tensor):\n", "        raise ValueError('Input should be a PyTorch tensor')\n", "    tensor = tensor.numpy().T  # (chroma bins, time)\n", "    plt.figure(figsize=(20, 20))\n", "    plt.imshow(tensor, cmap='binary', interpolation='nearest', origin='lower')\n", "    plt.show()\n", "\n", "# load salience matrix and display the corresponding wav\n", "melody_prompt_wav, melody_prompt_sr = torchaudio.load(f\"../assets/{fnames[file_idx]}.wav\")\n", "print(\"Source melody:\")\n", "display_audio(melody_prompt_wav, sample_rate=melody_prompt_sr)\n", "melody = torch.load(f\"../assets/{fnames[file_idx]}.th\", weights_only=True)\n", "plot_chromagram(melody)\n", "print(\"Chords:\")\n", "print(chords[file_idx])\n", "\n", "# ---------------------------------------------------------\n", "# use Demucs to separate the drums stem from the source mix\n", "# ---------------------------------------------------------\n", "def _get_drums_stem(wav: torch.Tensor, sample_rate: int) -> torch.Tensor:\n", "    \"\"\"Separate the source mix into stems and return the drums stem as mono audio.\"\"\"\n", "    demucs_model = pretrained.get_model('htdemucs').to('cuda')\n", "    wav = convert_audio(wav, sample_rate, demucs_model.samplerate, demucs_model.audio_channels)  # type: ignore\n", "    stems = apply_model(demucs_model, wav.cuda().unsqueeze(0), device='cuda').squeeze(0)\n", "    drum_stem = stems[demucs_model.sources.index('drums')]  # extract the stem used for drums conditioning\n", "    return convert_audio(drum_stem.cpu(), demucs_model.samplerate, sample_rate, 1)  # type: ignore\n", "\n", "drums_wav = _get_drums_stem(melody_prompt_wav, melody_prompt_sr)\n", "print(\"Separated drums:\")\n", "display_audio(drums_wav, sample_rate=melody_prompt_sr)\n", "\n", "# ----------------------------------\n", "# Generate using the loaded controls\n", "# ----------------------------------\n", "# free-form text prompts\n", "texts = [\n", "    '90s rock with heavy drums and hammond',\n", "    '80s pop with groovy synth bass and drum machine',\n", "    'folk song with leading accordion',\n", "]\n", "\n", "print(\"Generating...\")\n", "# replace the default dynamic solver with a simple Euler solver\n", "model.set_generation_params(cfg_coef_all=1.5, cfg_coef_txt=2.5, euler=True, euler_steps=50)\n", "output = model.generate_music(\n", "    descriptions=texts,\n", "    chords=chords[file_idx],\n", "    drums_wav=drums_wav,\n", "    drums_sample_rate=melody_prompt_sr,\n", "    melody_salience_matrix=melody.permute(1, 0),\n", "    progress=True\n", ")\n", "display_audio(output, sample_rate=model.compression_model.sample_rate)" ] }
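, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we passed three text prompts that share the same chords, drums and melody controls, `output` holds three clips. As a small sketch, we can audition each clip next to the prompt that produced it:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# label each generated clip with its text prompt\n", "for txt, wav in zip(texts, output):\n", "    print(f'Text: {txt}')\n", "    display_audio(wav, sample_rate=model.compression_model.sample_rate)" ] }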
], "metadata": { "kernelspec": { "display_name": "jasco_dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 2 }