diff --git a/src/models.js b/src/models.js index bb0c0ff1f..67c7f23ab 100644 --- a/src/models.js +++ b/src/models.js @@ -3389,6 +3389,11 @@ export class WhisperPreTrainedModel extends PreTrainedModel { */ export class WhisperModel extends WhisperPreTrainedModel { } +/** + * Whisper Encoder Model with a sequence classification head on top + * (a linear layer over the pooled output) for tasks like SUPERB Keyword Spotting. + */ +export class WhisperForAudioClassification extends WhisperPreTrainedModel { } /** * WhisperForConditionalGeneration class for generating conditional outputs from Whisper models. @@ -8349,6 +8354,7 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]], ['hubert', ['HubertForSequenceClassification', HubertForSequenceClassification]], ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]], + ['whisper', ['WhisperForAudioClassification', WhisperForAudioClassification]], ]); const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([ diff --git a/src/models/whisper/feature_extraction_whisper.js b/src/models/whisper/feature_extraction_whisper.js index a52a18e33..ad455d17f 100644 --- a/src/models/whisper/feature_extraction_whisper.js +++ b/src/models/whisper/feature_extraction_whisper.js @@ -70,8 +70,9 @@ export class WhisperFeatureExtractor extends FeatureExtractor { const length = max_length ?? this.config.n_samples; if (audio.length > length) { if (audio.length > this.config.n_samples) { + const seconds = Math.floor(length / this.config.sampling_rate); console.warn( - "Attempting to extract features for audio longer than 30 seconds. " + + `Attempting to extract features for audio longer than ${seconds} seconds. ` + "If using a pipeline to extract transcript from a long audio clip, " + "remember to specify `chunk_length_s` and/or `stride_length_s`." ); diff --git a/src/pipelines.js b/src/pipelines.js index 1032a7dea..beb6ca888 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -1515,9 +1515,13 @@ export class AudioClassificationPipeline extends (/** @type {new (options: Audio const output = await this.model(inputs); const logits = output.logits[0]; + const probabilities = logits.data.length > 1 + ? softmax(logits.data) + : logits.sigmoid().data; // Only one label, so we assume it's a binary classification + const scores = await topk(new Tensor( 'float32', - softmax(logits.data), + probabilities, logits.dims, ), top_k);