Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions lib_tflite_micro/src/tflite-xcore-kernels/xcore_mean.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
// Copyright (c) 2023, XMOS Ltd, All rights reserved

#include "xcore_custom_options.h"
#include "../thread_call.h"
#include "xcore_config.h"
#include "xcore_utils.h"
extern "C" {
#include "lib_nn/api/nn_operator.h"
#include "lib_nn/api/nn_layers.h"
}

Expand All @@ -12,6 +15,12 @@ namespace micro {
namespace xcore {
namespace mean {

// Per-thread work descriptor for the mean operator.
//
// One instance is filled in for each worker thread during Prepare and handed
// to the thread worker as its `arg0` pointer. Declared as a plain struct: the
// previous `typedef struct ... { } ;` form had no declarator, so the typedef
// keyword was ignored and only produced a compiler warning.
struct MeanWorkerArg0 {
  // Index into the shared int8 input buffer at which this thread's slice
  // begins.
  int input_offset;
  // Index into the shared int8 output buffer at which this thread writes.
  int output_offset;
  // Number of entries along the leading (start) dimension that this thread
  // processes; forwarded to mean_int8 as its start_dim_size argument.
  int start_dim_size;
};

// This is the struct that contains the data required by the operator
struct MeanOpData {
int start_dim_size;
Expand All @@ -20,8 +29,29 @@ struct MeanOpData {
float in_zero_point;
float out_zero_point;
float scale_mul;
int tc;
MeanWorkerArg0 arg0[XCORE_MAX_NUM_THREADS];
};

// State common to all worker threads of a single mean evaluation. A pointer
// to one instance is passed to mean_int8_thread_worker as its `shared`
// argument; each worker combines it with its own MeanWorkerArg0 offsets.
struct MeanShared {
// Base of the int8 input data; workers index it with their input_offset.
int8_t *input;
// Base of the int8 output data; workers index it with their output_offset.
int8_t *output;
// Operator parameters (dimension sizes, zero points, scale) read by workers.
MeanOpData *op_data;
};

extern "C" {
void mean_int8_thread_worker(void *shared, void *arg0, void *arg1) {
MeanWorkerArg0 *arg = static_cast<MeanWorkerArg0 *>(arg0);
(void) arg1;
auto sd = static_cast<MeanShared *>(shared);
int8_t *input = &sd->input[arg->input_offset];
int8_t *output = &sd->output[arg->output_offset];
mean_int8(input, output, arg->start_dim_size, sd->op_data->mean_dim_size,
sd->op_data->end_dim_size, sd->op_data->in_zero_point,
sd->op_data->out_zero_point, sd->op_data->scale_mul);
}
}

void *Init(TfLiteContext *context, const char *buffer, size_t length) {
auto op_data = construct_persistent_object<MeanOpData>(context);

Expand All @@ -37,6 +67,20 @@ void *Init(TfLiteContext *context, const char *buffer, size_t length) {

// Plans how the mean computation is divided between threads.
//
// Asks calculateAlignedThreadSplit to partition the leading (start) dimension
// across the model's thread count, stores the resulting thread count in
// op_data->tc, and converts each [start, end) range into the per-thread buffer
// offsets and slice length kept in op_data->arg0. No scratch buffers are
// requested here. Always returns kTfLiteOk.
TfLiteStatus Prepare(TfLiteContext *context, TfLiteNode *node) {
  auto *xc_config = reinterpret_cast<xc_context_config_t *>(
      GetMicroContext(context)->external_context());
  auto *op_data = static_cast<MeanOpData *>(node->user_data);

  int split_begin[XCORE_MAX_NUM_THREADS];
  int split_end[XCORE_MAX_NUM_THREADS];
  const int thread_count = calculateAlignedThreadSplit(
      xc_config->model_thread_count, op_data->start_dim_size, split_begin,
      split_end);
  op_data->tc = thread_count;

  // Each step along the start dimension spans mean_dim_size * end_dim_size
  // input elements and end_dim_size output elements.
  for (int t = 0; t < thread_count; ++t) {
    auto &job = op_data->arg0[t];
    job.input_offset =
        split_begin[t] * op_data->mean_dim_size * op_data->end_dim_size;
    job.output_offset = split_begin[t] * op_data->end_dim_size;
    job.start_dim_size = split_end[t] - split_begin[t];
  }

  return kTfLiteOk;
}

Expand All @@ -52,9 +96,38 @@ TfLiteStatus Eval(TfLiteContext *context, TfLiteNode *node) {
// Pointers to data in In/Out Tensors
int8_t *out_data = tflite_micro::micro::GetTensorData<int8_t>(output);
const int8_t *in_data = tflite_micro::micro::GetTensorData<int8_t>(input);
mean_int8(in_data, out_data, op_data->start_dim_size, op_data->mean_dim_size,
op_data->end_dim_size, op_data->in_zero_point,
op_data->out_zero_point, op_data->scale_mul);
MicroContext *micro_context = GetMicroContext(context);
xc_context_config_t *xc_config = reinterpret_cast<xc_context_config_t *>(
micro_context->external_context());
const int tc = op_data->tc;
if (tc == 1 && input->type == kTfLiteInt8) {
mean_int8(in_data, out_data, op_data->start_dim_size, op_data->mean_dim_size,
op_data->end_dim_size, op_data->in_zero_point,
op_data->out_zero_point, op_data->scale_mul);
return kTfLiteOk;
}
MeanShared shared_data;
shared_data.input = const_cast<int8_t *>(in_data);
shared_data.output = out_data;
shared_data.op_data = op_data;
for (int t = 0; t < tc - 1; t++) {
thread_variable_setup((void *)&op_data->arg0[t], nullptr,
xc_config->thread_info.thread_ids.id[t]);
}

thread_function_pointer_t fn;
switch (input->type) {
case kTfLiteInt8: {
fn = mean_int8_thread_worker;
break;
}
default: {
return kTfLiteError;
}
}

thread_call((void *)&shared_data, &op_data->arg0[tc - 1], nullptr,
(thread_function_pointer_t)fn, &xc_config->thread_info);

return kTfLiteOk;
}
Expand Down