-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathpredict.py
More file actions
119 lines (100 loc) · 3.8 KB
/
predict.py
File metadata and controls
119 lines (100 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import json
from os.path import abspath, dirname
import torch
import torchaudio
from cog import BasePredictor, Input, Path
from einops import rearrange
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict
from weights_downloader import WeightsDownloader
MODEL_PATH = "/src/models"
WEIGHTS_STR = "stable-audio-open-1.0"
class Predictor(BasePredictor):
    """Cog predictor wrapping the Stable Audio Open 1.0 text-to-audio diffusion model."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient."""
        self.model, self.model_config = self._load_model()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(device=self.device)

    def _load_model(self):
        """Download the weights (if missing) and construct the diffusion model.

        Returns:
            tuple: ``(model, model_config)`` where ``model_config`` is the dict
            parsed from ``model_config.json`` alongside the checkpoint.
        """
        weights_downloader = WeightsDownloader()
        weights_downloader.download_weights(WEIGHTS_STR, MODEL_PATH)

        model_config_path = f"{MODEL_PATH}/{WEIGHTS_STR}/model_config.json"
        model_ckpt_path = f"{MODEL_PATH}/{WEIGHTS_STR}/model.ckpt"

        with open(model_config_path) as f:
            model_config = json.load(f)

        model = create_model_from_config(model_config)
        model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
        return model, model_config

    def predict(
        self,
        prompt: str = Input(description="Text prompt describing the audio to generate"),
        negative_prompt: str = Input(
            default="", description="Text to steer the generation away from"
        ),
        seconds_start: int = Input(
            default=0, description="Start offset (seconds) used for timing conditioning"
        ),
        seconds_total: int = Input(
            default=8, le=47, description="Total length of generated audio in seconds"
        ),
        cfg_scale: float = Input(default=6.0, description="Classifier-free guidance scale"),
        steps: int = Input(default=100, description="Number of diffusion steps"),
        seed: int = Input(default=-1, description="Random seed; -1 picks a random seed"),
        sampler_type: str = Input(default="dpmpp-3m-sde", description="Diffusion sampler"),
        sigma_min: float = Input(default=0.03, description="Minimum noise sigma"),
        sigma_max: int = Input(default=500, description="Maximum noise sigma"),
        init_noise_level: float = Input(default=1.0, description="Initial noise level"),
        batch_size: int = Input(default=1, description="Number of clips generated per call"),
    ) -> Path:
        """Generate audio from a text prompt and return the path to a 16-bit WAV file."""
        # Defensive reload in case setup() did not run (or a previous load failed).
        if not self.model or not self.model_config:
            self.model, self.model_config = self._load_model()

        sample_rate = self.model_config["sample_rate"]
        # One extra second of headroom over the requested duration.
        sample_size = sample_rate * (seconds_total + 1)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Prompt: {prompt}")

        conditioning = [
            {
                "prompt": prompt,
                "seconds_start": seconds_start,
                "seconds_total": seconds_total,
            }
        ] * batch_size

        if negative_prompt:
            negative_conditioning = [
                {
                    "prompt": negative_prompt,
                    "seconds_start": seconds_start,
                    "seconds_total": seconds_total,
                }
            ] * batch_size
        else:
            negative_conditioning = None

        audio = generate_diffusion_cond(
            self.model,
            conditioning=conditioning,
            negative_conditioning=negative_conditioning,
            steps=steps,
            cfg_scale=cfg_scale,
            batch_size=batch_size,
            sample_size=sample_size,
            sample_rate=sample_rate,
            seed=seed,
            device=self.device,
            sampler_type=sampler_type,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            init_noise_level=init_noise_level,
        )

        # Collapse the batch dimension: concatenate all clips along time.
        audio = rearrange(audio, "b d n -> d (b n)")

        # Peak-normalize, then convert to 16-bit PCM. Guard against division by
        # zero: if the generated audio is pure silence, the original unguarded
        # div() produced NaNs and a corrupt WAV.
        audio = audio.to(torch.float32)
        peak = torch.max(torch.abs(audio))
        if peak > 0:
            audio = audio.div(peak)
        audio = audio.clamp(-1, 1).mul(32767).to(torch.int16).cpu()

        wav_path = "output.wav"
        torchaudio.save(wav_path, audio, sample_rate)
        return Path(wav_path)