
FAIL: Using cache and enabling back past_key_values cache #58

@msedalatzadeh

Description


I tried various transformers.js models, and those that support past_key_values do not actually handle them. I face several issues:

  1. The past_key_values returned by default are gpuBuffer (GPU-resident) tensors, but ONNX requires CPU tensors as input.
  2. Downloading the past_key_values to the CPU with the downloader method and running generation again runs into dimension inconsistencies. We need to feed input_ids, attention_mask and position_ids into model.generate(), and every shape I tried failed:
  • Let past_key_value.dims[2] = past_length and input_ids.dims[1] = full_length. I tried every combination of each input having length past_length, full_length, full_length - past_length, or just 1 (a single token). None worked.
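For comparison, the convention used by Hugging Face-style decoder ONNX exports (an assumption here; the exported input names and layout of the specific model should be checked) is that a cached forward pass with `past_len` cached tokens and `new_len` new tokens takes `input_ids` and `position_ids` of shape `[batch, new_len]`, an `attention_mask` of shape `[batch, past_len + new_len]`, and `past_key_values.*` tensors with `past_len` along dim 2. The hypothetical helper `checkCachedShapes` below checks those relationships on plain dims arrays, so a bad combination can be spotted before calling the model.

```javascript
// Hypothetical helper: checks the shape relationships a cached decoder
// forward pass is assumed to need (HF-style ONNX export convention).
// dims.past_key is assumed to be [batch, heads, past_len, head_dim].
function checkCachedShapes(dims) {
  const errors = [];
  const [batch, newLen] = dims.input_ids;
  const pastLen = dims.past_key[2];

  if (dims.past_key[0] !== batch) {
    errors.push(`past batch ${dims.past_key[0]} != input batch ${batch}`);
  }
  // position_ids cover only the new tokens
  if (dims.position_ids[0] !== batch || dims.position_ids[1] !== newLen) {
    errors.push(`position_ids should be [${batch}, ${newLen}]`);
  }
  // attention_mask covers cached + new tokens
  if (dims.attention_mask[0] !== batch || dims.attention_mask[1] !== pastLen + newLen) {
    errors.push(`attention_mask should be [${batch}, ${pastLen + newLen}]`);
  }
  return errors;
}

// Example: 40 cached tokens and 5 new tokens; no errors are reported.
console.log(checkCachedShapes({
  input_ids: [1, 5],
  attention_mask: [1, 45],
  position_ids: [1, 5],
  past_key: [1, 8, 40, 64],
}));
```

Checking the relationships, rather than fixed sizes, keeps the helper independent of the model's head count and head dimension.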

Please share a working example of transformers.js with past_key_values enabled.

Here is my code:

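import { Tensor } from '@huggingface/transformers'; // '@xenova/transformers' for v2

const full_inputs = tokenizer.apply_chat_template(messages, {
  add_generation_prompt: true,
  return_dict: true,
});

// Copy the cached key/value tensors from the previous turn (past_gpu_kv) to the CPU
const past_kv = {};
for (const key in past_gpu_kv) {
  if (past_gpu_kv[key]?.ort_tensor) {
    past_kv[key] = await convertToCPUTensor(past_gpu_kv[key].ort_tensor);
  }
}

// Reduce the prompt to the uncached tokens (buildInputsForGenerate is defined below;
// past_key_values_cache and modelKey come from the surrounding app code)
const inputs = buildInputsForGenerate(full_inputs, past_key_values_cache, modelKey);

const { past_key_values, sequences } = await model.generate({
  ...inputs,
  past_key_values: past_kv,
  use_cache: true,
  do_sample: false, // greedy decoding: top_k and temperature do not change the chosen tokens
  top_k: 3,
  temperature: 0.2,
  max_new_tokens: 1024,
  streamer,
  stopping_criteria,
  return_dict_in_generate: true,
});
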
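The loop above awaits each cache entry one after another. A sketch of the same copy step that issues all downloads before awaiting any of them, using `Promise.all`, which may reduce latency when there are many layers. `copyCacheConcurrently` is a hypothetical name, `download` stands in for whatever per-tensor conversion is used (for example `convertToCPUTensor`), and the mock tensors exist only for illustration.

```javascript
// Copy every entry of a key/value cache concurrently.
// cache: { [name]: { ort_tensor } }; download: async (ortTensor) => converted tensor.
async function copyCacheConcurrently(cache, download) {
  // Skip entries that do not wrap an ONNX Runtime tensor.
  const entries = Object.entries(cache).filter(([, t]) => t?.ort_tensor);
  // Start all downloads first, then wait for all of them.
  const converted = await Promise.all(
    entries.map(([, t]) => download(t.ort_tensor)),
  );
  // Rebuild an object with the same keys, in the same order.
  return Object.fromEntries(entries.map(([name], i) => [name, converted[i]]));
}

// Usage with mock tensors (illustration only).
const mockCache = {
  'past_key_values.0.key': { ort_tensor: { dims: [1, 8, 40, 64] } },
  'past_key_values.0.value': { ort_tensor: { dims: [1, 8, 40, 64] } },
  logits: {}, // no ort_tensor, so it is skipped
};
copyCacheConcurrently(mockCache, async (t) => ({ dims: t.dims }))
  .then((out) => console.log(Object.keys(out))); // logs the two past_key_values.* keys
```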
// Copy one ONNX Runtime tensor from GPU memory into a CPU-backed transformers.js Tensor
async function convertToCPUTensor(ortTensor) {
  if (!ortTensor || typeof ortTensor.downloader !== 'function') {
    throw new Error('Invalid ort_tensor: missing downloader method');
  }

  // Download the data from the GPU. For float16 this is the raw half-precision
  // data (e.g. a Uint16Array of bit patterns); for float32 a Float32Array.
  const rawData = await ortTensor.downloader();

  // Keep the original dtype and data unchanged: re-wrapping half-precision bits with
  // Float16Array.from() would reinterpret them as numbers and corrupt the cache.
  return new Tensor(ortTensor.type, rawData, ortTensor.dims);
}
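If float32 copies of float16 cache data are wanted (for example, to inspect values), a sketch of the decoding follows, assuming the downloaded float16 data is a `Uint16Array` of raw IEEE 754 half-precision bit patterns. The function names are hypothetical. The model's `past_key_values` inputs presumably still expect the dtype they were produced with, so this is for inspection rather than for feeding back.

```javascript
// Decode one IEEE 754 half-precision bit pattern (16 bits) into a JS number.
function halfToFloat(h) {
  const sign = h & 0x8000 ? -1 : 1;
  const exponent = (h >> 10) & 0x1f;
  const fraction = h & 0x03ff;
  if (exponent === 0) {
    // Zero or subnormal: 2^-14 * (fraction / 1024)
    return sign * Math.pow(2, -14) * (fraction / 1024);
  }
  if (exponent === 0x1f) {
    // Infinity or NaN
    return fraction ? NaN : sign * Infinity;
  }
  // Normal number: 2^(exponent - 15) * (1 + fraction / 1024)
  return sign * Math.pow(2, exponent - 15) * (1 + fraction / 1024);
}

// Convert a Uint16Array of half-precision bits into a Float32Array.
function float16BitsToFloat32(bits) {
  const out = new Float32Array(bits.length);
  for (let i = 0; i < bits.length; i++) out[i] = halfToFloat(bits[i]);
  return out;
}

// 0x3c00 = 1.0, 0xc000 = -2.0, 0x3800 = 0.5
console.log(float16BitsToFloat32(new Uint16Array([0x3c00, 0xc000, 0x3800])));
```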
// Build the inputs for a generation step that reuses a cached prefix.
// Assumes batch size 1 and that the first past_len prompt tokens are the cached ones.
function buildInputsForGenerate(full_inputs, past_key_values_cache, modelKey) {
  const input_ids_tensor = full_inputs.input_ids;

  // No cache yet: run the full prompt
  if (!past_key_values_cache[modelKey]) {
    return full_inputs;
  }

  const seq_len = input_ids_tensor.dims[1];
  if (seq_len === 0) {
    throw new Error("input_ids is empty: nothing to feed.");
  }
  // Use the past key dims [batch, heads, past_len, head_dim] to get the cached length
  const past = past_key_values_cache[modelKey];
  const past_len = past['past_key_values.0.key'].dims[2];
  const new_len = seq_len - past_len;
  if (new_len <= 0) {
    throw new Error(`No new tokens: seq_len=${seq_len}, past_len=${past_len}`);
  }

  // Keep only the tokens not yet in the cache: batch 0, positions past_len..seq_len-1
  const input_ids = input_ids_tensor.slice([0, 1], [past_len, seq_len]);

  // The mask covers cached + new tokens (all ones, no padding)
  const attention_mask = new Tensor(
    "int64",
    new BigInt64Array(seq_len).fill(1n),
    [1, seq_len]
  );

  // Positions continue from the end of the cache
  const position_ids = new Tensor(
    "int64",
    BigInt64Array.from([...Array(new_len).keys()].map(i => BigInt(past_len + i))),
    [1, new_len]
  );

  return {
    input_ids,
    attention_mask,
    position_ids,
  };
}
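Stripped of the library `Tensor` wrapper, the bookkeeping for one cached step reduces to three arrays. A library-free sketch, assuming a batch size of 1 and that the first `pastLen` tokens of the new prompt are exactly the tokens already in the cache (if the chat template tokenizes earlier turns differently from what was cached, that assumption breaks). `buildCachedStep` is a hypothetical name.

```javascript
// Build the per-step arrays for a cached forward pass (batch size 1).
// fullIds: BigInt64Array of the whole prompt; pastLen: tokens already in the cache.
function buildCachedStep(fullIds, pastLen) {
  const seqLen = fullIds.length;
  const newLen = seqLen - pastLen;
  if (newLen <= 0) {
    throw new Error(`no new tokens: seqLen=${seqLen}, pastLen=${pastLen}`);
  }
  // Only the tokens the cache has not seen yet.
  const input_ids = fullIds.slice(pastLen);
  // The mask spans cached + new tokens; all ones because there is no padding.
  const attention_mask = new BigInt64Array(seqLen).fill(1n);
  // Positions continue from where the cache stopped.
  const position_ids = BigInt64Array.from(
    { length: newLen },
    (_, i) => BigInt(pastLen + i),
  );
  return { input_ids, attention_mask, position_ids };
}

// 5 prompt tokens, 3 of them cached:
// input_ids = [9n, 10n], position_ids = [3n, 4n], attention_mask has length 5
const step = buildCachedStep(BigInt64Array.from([101n, 7n, 8n, 9n, 10n]), 3);
console.log(step.input_ids, step.position_ids, step.attention_mask.length);
```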
