-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathDiffSingerTensorCache.cs
More file actions
312 lines (285 loc) · 16.1 KB
/
Copy pathDiffSingerTensorCache.cs
File metadata and controls
312 lines (285 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Hashing;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using TuneLab.Foundation;
using TuneLab.SDK;
namespace DiffSingerForTuneLab;
// DiffSinger 张量缓存:把一个 ONNX 模型调用的输出按「模型文件哈希 + 序列化输入」为键缓存到磁盘,
// 反复合成(撤销重做、重开工程、改动不影响某块、跨块/跨说话人共享 linguistic 等)时直接复用、免重算。
// 忠实移植 OpenUtau DiffSingerCache(序列化/反序列化、文件格式与按 name 排序求键一致),差异仅在:
// · 哈希用 System.IO.Hashing.XxHash64(OpenUtau 用 K4os.Hash.xxHash),数值不同但本插件不需跨工具兼容;
// · 缓存目录取插件独立用户数据根 UserDataRoot/Cache(OpenUtau 用 PathManager.Inst.CachePath)。
// 另附编排封装:Run(建键→Load→未命中则模型 Run + Save,返回脱离原生内存的托管张量)、
// Clone(把模型原生输出深拷成托管,供未命中/禁用时安全返回)、HashFile(模型 identifier)、EnforceSizeLimit(LRU 逐出)。
public sealed class DiffSingerTensorCache
{
const string FormatHeader = "TENSORCACHE";
readonly ulong mHash;
readonly string mFilename;
public ulong Hash => mHash;
public string Filename => mFilename;
// 缓存目录:插件用户数据根下的 Cache(与 Voices/Vocoders 并列)。
public static string CacheDirectory => Path.Combine(DiffSingerDeclarations.UserDataRoot, "Cache");
public DiffSingerTensorCache(ulong identifier, IReadOnlyCollection<NamedOnnxValue> inputs)
{
using var stream = new MemoryStream();
using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
{
writer.Write(identifier);
foreach (var onnxValue in inputs.OrderBy(v => v.Name, StringComparer.InvariantCulture))
SerializeNamedOnnxValue(writer, onnxValue);
}
mHash = XxHash64.HashToUInt64(stream.ToArray());
mFilename = $"ds-{mHash:x16}.tensorcache";
}
// —— 编排封装 ——
// 推理串行化锁(进程级单锁):DirectML EP 下 InferenceSession.Run() 不可并发——约束是设备级的,
// 不仅同一会话,同一 GPU 上不同会话并发 Run 也会崩/出错。宿主会为每条 part 各开合成会话,
// 且多会话经引擎级缓存共享声学/声码器会话,故必须跨会话全局串行(按会话实例上锁挡不住跨 voiceId 的并发)。
// CPU EP 虽线程安全,但单次 Run 已用满 intra-op 线程池,并发 Run 只会过订阅线程;全局串行对其亦中性偏好,
// 换取「无论 provider / 会话数 / 模型束都正确」的简单性。锁只罩 Run 本身,缓存 Load/Save 等磁盘 IO 在锁外。
static readonly object sRunLock = new();
// 已退役会话登记:随模型缓存释放的会话在 sRunLock 内先登记于此,令其后任何 Run 干净抛出 ObjectDisposedException,
// 而非把已释放的原生句柄喂进 model.Run(→ AccessViolationException)。根治关闭 / 换执行设备时的并发释放崩溃。
// 用 ConditionalWeakTable 持弱引用:会话被缓存解引用后条目随 GC 回收,不泄漏;按会话实例隔离,
// 故换设备新建的会话天然不在表内、不受影响。value 仅作存在性标记。
static readonly ConditionalWeakTable<InferenceSession, object> sRetired = new();
static readonly object sRetiredMarker = new();
static IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunSerialized(
InferenceSession model, IReadOnlyCollection<NamedOnnxValue> inputs)
{
lock (sRunLock)
{
if (sRetired.TryGetValue(model, out _))
throw new ObjectDisposedException(nameof(InferenceSession), "DiffSinger:模型会话已随缓存释放,推理取消。");
return model.Run(inputs);
}
}
// 在全局推理锁内退役并释放一组会话:持锁即保证当前无 Run 在飞(Run 全程持同一把锁),故释放不会与在飞推理并发
// 触发 use-after-free;释放前先登记退役,杜绝多级管线在本次释放后还想跑下一阶段时触碰已释放会话。
// 供模型束 / 预测器 / 缓存释放各自的会话时调用——分批各取一次锁即可(跨批的 Run 仍受同锁串行保护)。
public static void RetireAndDispose(IEnumerable<InferenceSession> sessions)
{
lock (sRunLock)
foreach (var session in sessions)
{
sRetired.AddOrUpdate(session, sRetiredMarker);
session.Dispose();
}
}
// 跑一个模型并经缓存:命中直接返回缓存输出;未命中则 Run + Save。返回的张量均为托管 DenseTensor(脱离原生内存,可在调用处延后读取)。
// enabled=false 时跳过磁盘缓存,但仍把原生输出深拷为托管返回(调用方约定可在 Run 作用域外读取输出)。
public static IReadOnlyList<NamedOnnxValue> Run(
InferenceSession model, ulong identifier, IReadOnlyCollection<NamedOnnxValue> inputs, bool enabled)
{
if (!enabled)
{
using var raw = RunSerialized(model, inputs);
return Clone(raw);
}
var cache = new DiffSingerTensorCache(identifier, inputs);
var loaded = cache.Load();
if (loaded != null)
return loaded;
using var run = RunSerialized(model, inputs);
var result = Clone(run); // 先脱离原生内存,再落盘(Save 读托管张量即可)
cache.Save(result);
return result;
}
// 把(可能由原生 OrtValue 支撑的)输出深拷为托管 DenseTensor,使其在原生集合 Dispose 后仍可安全读取。
// 复用序列化/反序列化的类型分支(往返一次内存流),零重复代码;相对扩散推理开销可忽略。
public static List<NamedOnnxValue> Clone(IEnumerable<NamedOnnxValue> values)
{
var list = new List<NamedOnnxValue>();
foreach (var v in values)
{
using var ms = new MemoryStream();
using (var w = new BinaryWriter(ms, Encoding.UTF8, leaveOpen: true))
SerializeNamedOnnxValue(w, v);
ms.Position = 0;
using var r = new BinaryReader(ms);
list.Add(DeserializeNamedOnnxValue(r));
}
return list;
}
// 模型 identifier:.onnx 文件内容的 XxHash64(流式、不整体载入内存)。加载时算一次缓存进字段,
// 用作缓存键的一部分,区分不同模型(同输入不同权重不撞键),且模型文件更换即自动失效。
public static ulong HashFile(string path)
{
var h = new XxHash64();
using var fs = File.OpenRead(path);
h.Append(fs);
return h.GetCurrentHashAsUInt64();
}
// LRU 体积上限逐出:缓存目录超过上限时,按最近访问时间删最旧的 .tensorcache 直到回落。maxSizeMb<=0 视作不限制。
// 尽力而为,任何 IO 异常吞掉(逐出失败不应影响合成)。
public static void EnforceSizeLimit(long maxSizeMb)
{
if (maxSizeMb <= 0)
return;
try
{
var dir = CacheDirectory;
if (!Directory.Exists(dir))
return;
var files = new DirectoryInfo(dir).GetFiles("*.tensorcache");
long total = files.Sum(f => f.Length);
long limit = maxSizeMb * 1024L * 1024L;
if (total <= limit)
return;
foreach (var f in files.OrderBy(f => f.LastAccessTimeUtc))
{
try { long len = f.Length; f.Delete(); total -= len; } catch { }
if (total <= limit)
break;
}
}
catch { }
}
public IReadOnlyList<NamedOnnxValue>? Load()
{
var cachePath = Path.Join(CacheDirectory, mFilename);
if (!File.Exists(cachePath))
return null;
var result = new List<NamedOnnxValue>();
try
{
using (var stream = new FileStream(cachePath, FileMode.Open, FileAccess.Read))
using (var reader = new BinaryReader(stream))
{
if (reader.ReadString() != FormatHeader)
throw new InvalidDataException($"[TensorCache] 缓存文件头异常:{mFilename}。");
var count = reader.ReadInt32();
for (var i = 0; i < count; ++i)
result.Add(DeserializeNamedOnnxValue(reader));
}
}
catch (Exception e)
{
TuneLabContext.Global.GetLogger().Warning($"DiffSinger:反序列化缓存 {mFilename} 失败、丢弃重算:{e.Message}");
Delete();
return null;
}
// 命中即「访问」:显式刷新访问时间,令 LRU 逐出以真实使用近度排序(不依赖 NTFS 自动 last-access 策略)。
try { File.SetLastAccessTimeUtc(cachePath, DateTime.UtcNow); } catch { }
return result;
}
public void Delete()
{
var cachePath = Path.Join(CacheDirectory, mFilename);
if (File.Exists(cachePath))
{
try { File.Delete(cachePath); } catch { }
}
}
public void Save(IReadOnlyCollection<NamedOnnxValue> outputs)
{
Directory.CreateDirectory(CacheDirectory);
var cachePath = Path.Join(CacheDirectory, mFilename);
using var stream = new FileStream(cachePath, FileMode.Create, FileAccess.Write);
using var writer = new BinaryWriter(stream);
writer.Write(FormatHeader);
writer.Write(outputs.Count);
foreach (var onnxValue in outputs)
SerializeNamedOnnxValue(writer, onnxValue);
}
static void SerializeNamedOnnxValue(BinaryWriter writer, NamedOnnxValue namedOnnxValue)
{
if (namedOnnxValue.ValueType != OnnxValueType.ONNX_TYPE_TENSOR)
throw new NotSupportedException(
$"[TensorCache] 仅支持张量类型 {OnnxValueType.ONNX_TYPE_TENSOR},遇 {namedOnnxValue.ValueType}。");
writer.Write(namedOnnxValue.Name);
var tensorBase = (TensorBase)namedOnnxValue.Value;
var elementType = tensorBase.GetTypeInfo().ElementType;
writer.Write((int)elementType);
switch (elementType)
{
case TensorElementType.Float: SerializeTensor(writer, namedOnnxValue.AsTensor<float>()); break;
case TensorElementType.UInt8: SerializeTensor(writer, namedOnnxValue.AsTensor<byte>()); break;
case TensorElementType.Int8: SerializeTensor(writer, namedOnnxValue.AsTensor<sbyte>()); break;
case TensorElementType.UInt16: SerializeTensor(writer, namedOnnxValue.AsTensor<ushort>()); break;
case TensorElementType.Int16: SerializeTensor(writer, namedOnnxValue.AsTensor<short>()); break;
case TensorElementType.Int32: SerializeTensor(writer, namedOnnxValue.AsTensor<int>()); break;
case TensorElementType.Int64: SerializeTensor(writer, namedOnnxValue.AsTensor<long>()); break;
case TensorElementType.String: SerializeTensor(writer, namedOnnxValue.AsTensor<string>()); break;
case TensorElementType.Bool: SerializeTensor(writer, namedOnnxValue.AsTensor<bool>()); break;
case TensorElementType.Float16: SerializeTensor(writer, namedOnnxValue.AsTensor<Float16>()); break;
case TensorElementType.Double: SerializeTensor(writer, namedOnnxValue.AsTensor<double>()); break;
case TensorElementType.UInt32: SerializeTensor(writer, namedOnnxValue.AsTensor<uint>()); break;
case TensorElementType.UInt64: SerializeTensor(writer, namedOnnxValue.AsTensor<ulong>()); break;
case TensorElementType.BFloat16: SerializeTensor(writer, namedOnnxValue.AsTensor<BFloat16>()); break;
default:
throw new NotSupportedException($"[TensorCache] 不支持的张量元素类型:{elementType}。");
}
}
static void SerializeTensor<T>(BinaryWriter writer, Tensor<T> tensor)
{
if (tensor.IsReversedStride)
throw new NotSupportedException("[TensorCache] 不支持反序步幅张量。");
writer.Write(tensor.Rank);
foreach (var dim in tensor.Dimensions)
writer.Write(dim);
var size = (int)tensor.Length;
writer.Write(size);
if (typeof(T) == typeof(string))
{
foreach (var element in tensor.ToArray())
writer.Write(element?.ToString() ?? string.Empty);
}
else
{
var data = new byte[size * tensor.GetTypeInfo().TypeSize];
Buffer.BlockCopy(tensor.ToArray(), 0, data, 0, data.Length);
writer.Write(data);
}
}
static NamedOnnxValue DeserializeNamedOnnxValue(BinaryReader reader)
{
var name = reader.ReadString();
var dtype = (TensorElementType)reader.ReadInt32();
var rank = reader.ReadInt32();
int[] shape = new int[rank];
for (var i = 0; i < rank; ++i)
shape[i] = reader.ReadInt32();
var size = reader.ReadInt32();
switch (dtype)
{
case TensorElementType.Float: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<float>(reader, size, sizeof(float), shape));
case TensorElementType.UInt8: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<byte>(reader, size, sizeof(byte), shape));
case TensorElementType.Int8: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<sbyte>(reader, size, sizeof(sbyte), shape));
case TensorElementType.UInt16: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<ushort>(reader, size, sizeof(ushort), shape));
case TensorElementType.Int16: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<short>(reader, size, sizeof(short), shape));
case TensorElementType.Int32: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<int>(reader, size, sizeof(int), shape));
case TensorElementType.Int64: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<long>(reader, size, sizeof(long), shape));
case TensorElementType.String:
{
Tensor<string> tensor = new DenseTensor<string>(size);
for (var i = 0; i < size; ++i)
tensor[i] = reader.ReadString();
tensor = tensor.Reshape(shape);
return NamedOnnxValue.CreateFromTensor(name, tensor);
}
case TensorElementType.Bool: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<bool>(reader, size, sizeof(bool), shape));
case TensorElementType.Float16: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<Float16>(reader, size, sizeof(ushort), shape));
case TensorElementType.Double: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<double>(reader, size, sizeof(double), shape));
case TensorElementType.UInt32: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<uint>(reader, size, sizeof(uint), shape));
case TensorElementType.UInt64: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<ulong>(reader, size, sizeof(ulong), shape));
case TensorElementType.BFloat16: return NamedOnnxValue.CreateFromTensor(name, DeserializeTensor<BFloat16>(reader, size, sizeof(ushort), shape));
default:
throw new NotSupportedException($"[TensorCache] 不支持的张量元素类型:{dtype}。");
}
}
static Tensor<T> DeserializeTensor<T>(BinaryReader reader, int size, int typeSize, ReadOnlySpan<int> shape)
{
var bytes = reader.ReadBytes(size * typeSize);
var data = new T[size];
Buffer.BlockCopy(bytes, 0, data, 0, bytes.Length);
return new DenseTensor<T>(data, shape);
}
}