GCC Code Coverage Report
Directory: ../../../ffmpeg/ Exec Total Coverage
File: src/libavfilter/af_arnndn.c Lines: 0 747 0.0 %
Date: 2021-04-14 23:45:22 Branches: 0 706 0.0 %

Line Branch Exec Source
1
/*
2
 * Copyright (c) 2018 Gregor Richards
3
 * Copyright (c) 2017 Mozilla
4
 * Copyright (c) 2005-2009 Xiph.Org Foundation
5
 * Copyright (c) 2007-2008 CSIRO
6
 * Copyright (c) 2008-2011 Octasic Inc.
7
 * Copyright (c) Jean-Marc Valin
8
 * Copyright (c) 2019 Paul B Mahol
9
 *
10
 * Redistribution and use in source and binary forms, with or without
11
 * modification, are permitted provided that the following conditions
12
 * are met:
13
 *
14
 * - Redistributions of source code must retain the above copyright
15
 *   notice, this list of conditions and the following disclaimer.
16
 *
17
 * - Redistributions in binary form must reproduce the above copyright
18
 *   notice, this list of conditions and the following disclaimer in the
19
 *   documentation and/or other materials provided with the distribution.
20
 *
21
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
 * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
 */
33
34
#include <float.h>
35
36
#include "libavutil/avassert.h"
37
#include "libavutil/avstring.h"
38
#include "libavutil/float_dsp.h"
39
#include "libavutil/mem_internal.h"
40
#include "libavutil/opt.h"
41
#include "libavutil/tx.h"
42
#include "avfilter.h"
43
#include "audio.h"
44
#include "filters.h"
45
#include "formats.h"
46
47
/* ---- Frame geometry (in samples) ----
 * One hop is 120 << 2 = 480 samples (10 ms at the 48 kHz rate that
 * query_formats() enforces). The analysis window spans two hops, and the
 * transforms keep FREQ_SIZE = FRAME_SIZE + 1 bins of the WINDOW_SIZE FFT. */
#define FRAME_SIZE_SHIFT  2
#define FRAME_SIZE        (120 << FRAME_SIZE_SHIFT)
#define WINDOW_SIZE       (2 * FRAME_SIZE)
#define FREQ_SIZE         (FRAME_SIZE + 1)

/* ---- Pitch analysis ---- */
#define PITCH_MIN_PERIOD  60
#define PITCH_MAX_PERIOD  768
#define PITCH_FRAME_SIZE  960
#define PITCH_BUF_SIZE    (PITCH_MAX_PERIOD + PITCH_FRAME_SIZE)

/* x * x -- the argument is evaluated twice, so pass side-effect-free expressions. */
#define SQUARE(x) ((x) * (x))

/* ---- Feature layout ---- */
#define NB_BANDS          22
#define CEPS_MEM          8
#define NB_DELTA_CEPS     6
#define NB_FEATURES       (NB_BANDS + 3 * NB_DELTA_CEPS + 2)   /* 42 */

/* ---- Neural network ----
 * WEIGHTS_SCALE: presumably applied to the integer weights read by
 * rnnoise_model_from_file() -- confirm in the inference code. */
#define WEIGHTS_SCALE     (1.f / 256)
#define MAX_NEURONS       128

/* Activation ids stored in DenseLayer/GRULayer.activation. */
#define ACTIVATION_TANH    0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU    2

/* Unity value (float). */
#define Q15ONE 1.0f
75
76
/* Fully-connected layer, filled in by rnnoise_model_from_file(). */
typedef struct DenseLayer {
    const float *bias;            /* nb_neurons values */
    const float *input_weights;   /* nb_inputs * nb_neurons values */
    int nb_inputs, nb_neurons;
    int activation;               /* one of ACTIVATION_* */
} DenseLayer;

/* GRU layer, filled in by rnnoise_model_from_file(). The weight matrices hold
 * 3 blocks (presumably the GRU gates), stored with the input dimension padded
 * to a multiple of 4 (see INPUT_ARRAY3). */
typedef struct GRULayer {
    const float *bias;              /* 3 * nb_neurons values */
    const float *input_weights;
    const float *recurrent_weights;
    int nb_inputs, nb_neurons;
    int activation;                 /* one of ACTIVATION_* */
} GRULayer;

/* Complete network. Every *_size field mirrors the nb_neurons of the layer
 * stored right after it. */
typedef struct RNNModel {
    int input_dense_size;
    const DenseLayer *input_dense;

    int vad_gru_size;
    const GRULayer *vad_gru;

    int noise_gru_size;
    const GRULayer *noise_gru;

    int denoise_gru_size;
    const GRULayer *denoise_gru;

    int denoise_output_size;
    const DenseLayer *denoise_output;

    int vad_output_size;
    const DenseLayer *vad_output;
} RNNModel;

/* Per-channel recurrent state; the three buffers are allocated by config_input(). */
typedef struct RNNState {
    float *vad_gru_state, *noise_gru_state, *denoise_gru_state;
    RNNModel *model;
} RNNState;
119
120
typedef struct DenoiseState {
121
    float analysis_mem[FRAME_SIZE];
122
    float cepstral_mem[CEPS_MEM][NB_BANDS];
123
    int memid;
124
    DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
125
    float pitch_buf[PITCH_BUF_SIZE];
126
    float pitch_enh_buf[PITCH_BUF_SIZE];
127
    float last_gain;
128
    int last_period;
129
    float mem_hp_x[2];
130
    float lastg[NB_BANDS];
131
    float history[FRAME_SIZE];
132
    RNNState rnn[2];
133
    AVTXContext *tx, *txi;
134
    av_tx_fn tx_fn, txi_fn;
135
} DenoiseState;
136
137
typedef struct AudioRNNContext {
138
    const AVClass *class;
139
140
    char *model_name;
141
    float mix;
142
143
    int channels;
144
    DenoiseState *st;
145
146
    DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
147
    DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
148
149
    RNNModel *model[2];
150
151
    AVFloatDSPContext *fdsp;
152
} AudioRNNContext;
153
154
/* Activation codes as written in the model file. INPUT_ACTIVATION translates
 * them to ACTIVATION_*; any unrecognised code falls back to tanh. */
#define F_ACTIVATION_TANH    0
#define F_ACTIVATION_SIGMOID 1
#define F_ACTIVATION_RELU    2
157
158
static void rnnoise_model_free(RNNModel *model)
159
{
160
#define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
161
#define FREE_DENSE(name) do { \
162
    if (model->name) { \
163
        av_free((void *) model->name->input_weights); \
164
        av_free((void *) model->name->bias); \
165
        av_free((void *) model->name); \
166
    } \
167
    } while (0)
168
#define FREE_GRU(name) do { \
169
    if (model->name) { \
170
        av_free((void *) model->name->input_weights); \
171
        av_free((void *) model->name->recurrent_weights); \
172
        av_free((void *) model->name->bias); \
173
        av_free((void *) model->name); \
174
    } \
175
    } while (0)
176
177
    if (!model)
178
        return;
179
    FREE_DENSE(input_dense);
180
    FREE_GRU(vad_gru);
181
    FREE_GRU(noise_gru);
182
    FREE_GRU(denoise_gru);
183
    FREE_DENSE(denoise_output);
184
    FREE_DENSE(vad_output);
185
    av_free(model);
186
}
187
188
/**
 * Parse a text "rnnoise-nu" model (format version 1) from f.
 *
 * Layers are read in this order: input_dense, vad_gru, noise_gru,
 * denoise_gru, denoise_output, vad_output. Each layer starts with
 * "nb_inputs nb_neurons activation" (each 0..128). The weight and bias
 * arrays follow as whitespace-separated integers, and the rest of the line
 * is skipped after every section.
 *
 * @param f    model file, positioned at the version header
 * @param rnn  on success receives the new model (free with rnnoise_model_free())
 * @return 0 on success; AVERROR_INVALIDDATA for a missing/unknown header,
 *         AVERROR(ENOMEM) on allocation failure, AVERROR(EINVAL) for
 *         malformed data or layer dimensions this filter cannot use.
 *         Nothing is leaked on failure.
 */
static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
{
    RNNModel *ret = NULL;
    DenseLayer *input_dense;
    GRULayer *vad_gru;
    GRULayer *noise_gru;
    GRULayer *denoise_gru;
    DenseLayer *denoise_output;
    DenseLayer *vad_output;
    int in;

    if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
        return AVERROR_INVALIDDATA;

    ret = av_calloc(1, sizeof(RNNModel));
    if (!ret)
        return AVERROR(ENOMEM);

/* Allocate a zeroed layer and attach it to ret immediately, so that every
 * later failure path can release it via rnnoise_model_free(). */
#define ALLOC_LAYER(type, name) \
    name = av_calloc(1, sizeof(type)); \
    if (!name) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    ret->name = name

    ALLOC_LAYER(DenseLayer, input_dense);
    ALLOC_LAYER(GRULayer, vad_gru);
    ALLOC_LAYER(GRULayer, noise_gru);
    ALLOC_LAYER(GRULayer, denoise_gru);
    ALLOC_LAYER(DenseLayer, denoise_output);
    ALLOC_LAYER(DenseLayer, vad_output);

/* Read one integer in 0..128 (the MAX_NEURONS value). This bound caps every
 * array allocated below. */
#define INPUT_VAL(name) do { \
    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
        rnnoise_model_free(ret); \
        return AVERROR(EINVAL); \
    } \
    name = in; \
    } while (0)

/* Map a file activation code to ACTIVATION_*; unknown codes become tanh. */
#define INPUT_ACTIVATION(name) do { \
    int activation; \
    INPUT_VAL(activation); \
    switch (activation) { \
    case F_ACTIVATION_SIGMOID: \
        name = ACTIVATION_SIGMOID; \
        break; \
    case F_ACTIVATION_RELU: \
        name = ACTIVATION_RELU; \
        break; \
    default: \
        name = ACTIVATION_TANH; \
    } \
    } while (0)

/* Read len integers into a new float array. The array is stored in the layer
 * before parsing, so a parse error still frees it. */
#define INPUT_ARRAY(name, len) do { \
    float *values = av_calloc((len), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int i = 0; i < (len); i++) { \
        if (fscanf(f, "%d", &in) != 1) { \
            rnnoise_model_free(ret); \
            return AVERROR(EINVAL); \
        } \
        values[i] = in; \
    } \
    } while (0)

/* Read a len0 x len2 x len1 weight tensor (file order) into a transposed
 * layout whose innermost dimension (len0) is padded to a multiple of 4. */
#define INPUT_ARRAY3(name, len0, len1, len2) do { \
    float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int k = 0; k < (len0); k++) { \
        for (int i = 0; i < (len2); i++) { \
            for (int j = 0; j < (len1); j++) { \
                if (fscanf(f, "%d", &in) != 1) { \
                    rnnoise_model_free(ret); \
                    return AVERROR(EINVAL); \
                } \
                values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
            } \
        } \
    } \
    } while (0)

/* Skip to just past the next '\n' (or EOF). */
#define NEW_LINE() do { \
    int c; \
    while ((c = fgetc(f)) != EOF) { \
        if (c == '\n') \
        break; \
    } \
    } while (0)

#define INPUT_DENSE(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons); \
    NEW_LINE(); \
    } while (0)

#define INPUT_GRU(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
    NEW_LINE(); \
    } while (0)

    INPUT_DENSE(input_dense);
    INPUT_GRU(vad_gru);
    INPUT_GRU(noise_gru);
    INPUT_GRU(denoise_gru);
    INPUT_DENSE(denoise_output);
    INPUT_DENSE(vad_output);

    /* The network has to fit this file's fixed-size pipeline: NB_FEATURES
     * input features, NB_BANDS band gains and a single VAD output.
     * NOTE(review): the inference code is outside this view. These checks
     * assume it uses NB_FEATURES/NB_BANDS-sized buffers -- confirm. */
    if (vad_output->nb_neurons != 1 ||
        input_dense->nb_inputs != NB_FEATURES ||
        denoise_output->nb_neurons != NB_BANDS) {
        rnnoise_model_free(ret);
        return AVERROR(EINVAL);
    }

    *rnn = ret;

    return 0;
}
330
331
/* Accept planar float audio, any channel count, 48 kHz only. */
static int query_formats(AVFilterContext *ctx)
{
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP, AV_SAMPLE_FMT_NONE
    };
    static const int sample_rates[] = { 48000, -1 };
    AVFilterChannelLayouts *layouts;
    AVFilterFormats *formats;
    int ret;

    if (!(formats = ff_make_format_list(sample_fmts)))
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_formats(ctx, formats)) < 0)
        return ret;

    if (!(layouts = ff_all_channel_counts()))
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_channel_layouts(ctx, layouts)) < 0)
        return ret;

    if (!(formats = ff_make_format_list(sample_rates)))
        return AVERROR(ENOMEM);
    return ff_set_common_samplerates(ctx, formats);
}
361
362
/**
 * Configure the input link. Allocates per-channel denoiser state and the
 * GRU state buffers for model[0], and creates the forward/inverse FFTs.
 *
 * @return 0 on success, a negative AVERROR code on failure.
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;
    /* Initialised so no check below can see an indeterminate value. */
    int ret = 0;

    s->channels = inlink->channels;

    /* NOTE(review): s->st is only allocated on the first call, so its size
     * follows the channel count seen then -- confirm it cannot change. */
    if (!s->st)
        s->st = av_calloc(s->channels, sizeof(DenoiseState));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int i = 0; i < s->channels; i++) {
        RNNState *rnn = &s->st[i].rnn[0];
        const RNNModel *model = s->model[0];

        rnn->model = s->model[0];

        /* Drop buffers from a previous call instead of leaking them;
         * av_free(NULL) is a no-op on the first call. */
        av_free(rnn->vad_gru_state);
        av_free(rnn->noise_gru_state);
        av_free(rnn->denoise_gru_state);

        /* Zeroed state, padded to a multiple of 16 floats. */
        rnn->vad_gru_state     = av_calloc(sizeof(float), FFALIGN(model->vad_gru_size, 16));
        rnn->noise_gru_state   = av_calloc(sizeof(float), FFALIGN(model->noise_gru_size, 16));
        rnn->denoise_gru_state = av_calloc(sizeof(float), FFALIGN(model->denoise_gru_size, 16));
        if (!rnn->vad_gru_state ||
            !rnn->noise_gru_state ||
            !rnn->denoise_gru_state)
            return AVERROR(ENOMEM);
    }

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        if (!st->tx) {
            ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }

        if (!st->txi) {
            ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }
    }

    return 0;
}
404
405
/*
 * Second-order IIR section (transposed direct form) with an implicit unit
 * leading coefficient:
 *   y[n] = x[n] + mem[0]
 *   mem[0] <- mem[1] + b[0]*x[n] - a[0]*y[n]
 *   mem[1] <- b[1]*x[n] - a[1]*y[n]
 * mem carries the state across calls. In-place use (y == x) is allowed.
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    for (int n = 0; n < N; n++) {
        const float in  = x[n];
        const float out = in + mem[0];

        mem[0] = mem[1] + (b[0] * in - a[0] * out);
        mem[1] = b[1] * in - a[1] * out;
        y[n]   = out;
    }
}
418
419
/* Element-count wrappers around memmove/memset/memcpy: n counts elements of
 * *dst, not bytes. The "+ 0 * ((dst) - (src))" term always adds zero. It is
 * there only so the compiler rejects dst/src of incompatible pointer types. */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n) * sizeof(*(dst)) + 0 * ((dst) - (src))))
#define RNN_CLEAR(dst, n)     (memset((dst), 0, (n) * sizeof(*(dst))))
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n) * sizeof(*(dst)) + 0 * ((dst) - (src))))
422
423
/**
 * Forward FFT of one real WINDOW_SIZE-sample frame.
 *
 * Only the first FREQ_SIZE bins (DC through Nyquist) are copied to out. For
 * a real input the remaining bins are the complex-conjugate mirror, which
 * inverse_transform() rebuilds.
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    AVComplexFloat x[WINDOW_SIZE];
    AVComplexFloat y[WINDOW_SIZE];

    for (int i = 0; i < WINDOW_SIZE; i++) {
        x[i].re = in[i];
        x[i].im = 0;
    }

    /* av_tx complex FFTs expect the stride to be the size of one
     * (complex) sample in bytes. */
    st->tx_fn(st->tx, y, x, sizeof(AVComplexFloat));

    RNN_COPY(out, y, FREQ_SIZE);
}
437
438
/**
 * Inverse of forward_transform(): takes FREQ_SIZE bins, rebuilds the full
 * Hermitian spectrum, and writes WINDOW_SIZE real samples scaled by
 * 1 / WINDOW_SIZE.
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    AVComplexFloat x[WINDOW_SIZE];
    AVComplexFloat y[WINDOW_SIZE];

    RNN_COPY(x, in, FREQ_SIZE);

    /* Upper half = conjugate mirror of the lower half, so the output is real. */
    for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
        x[i].re =  x[WINDOW_SIZE - i].re;
        x[i].im = -x[WINDOW_SIZE - i].im;
    }

    /* av_tx complex FFTs expect the stride to be the size of one
     * (complex) sample in bytes. */
    st->txi_fn(st->txi, y, x, sizeof(AVComplexFloat));

    /* The transform does not normalise, so divide by the length here. */
    for (int i = 0; i < WINDOW_SIZE; i++)
        out[i] = y[i].re / WINDOW_SIZE;
}
455
456
/* Band edges in units of (1 << FRAME_SIZE_SHIFT) FFT bins, i.e. 200 Hz steps.
 * The comment rows give each edge's frequency. */
static const uint8_t eband5ms[] = {
    /* 0  200  400  600  800   1k 1.2k 1.4k 1.6k   2k 2.4k 2.8k */
       0,   1,   2,   3,   4,   5,   6,   7,   8,  10,  12,  14,
    /* 3.2k 4k 4.8k 5.6k 6.8k   8k 9.6k  12k 15.6k 20k */
      16,  20,  24,  28,  34,  40,  48,  60,  78, 100,
};
460
461
static void compute_band_energy(float *bandE, const AVComplexFloat *X)
462
{
463
    float sum[NB_BANDS] = {0};
464
465
    for (int i = 0; i < NB_BANDS - 1; i++) {
466
        int band_size;
467
468
        band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
469
        for (int j = 0; j < band_size; j++) {
470
            float tmp, frac = (float)j / band_size;
471
472
            tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
473
            tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
474
            sum[i]     += (1.f - frac) * tmp;
475
            sum[i + 1] +=        frac  * tmp;
476
        }
477
    }
478
479
    sum[0] *= 2;
480
    sum[NB_BANDS - 1] *= 2;
481
482
    for (int i = 0; i < NB_BANDS; i++)
483
        bandE[i] = sum[i];
484
}
485
486
static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
487
{
488
    float sum[NB_BANDS] = { 0 };
489
490
    for (int i = 0; i < NB_BANDS - 1; i++) {
491
        int band_size;
492
493
        band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
494
        for (int j = 0; j < band_size; j++) {
495
            float tmp, frac = (float)j / band_size;
496
497
            tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
498
            tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
499
            sum[i]     += (1 - frac) * tmp;
500
            sum[i + 1] +=      frac  * tmp;
501
        }
502
    }
503
504
    sum[0] *= 2;
505
    sum[NB_BANDS-1] *= 2;
506
507
    for (int i = 0; i < NB_BANDS; i++)
508
        bandE[i] = sum[i];
509
}
510
511
/*
 * Build the analysis window (previous hop followed by the new FRAME_SIZE
 * samples), apply s->window, FFT it into X, and compute band energies into Ex.
 * The new hop is saved in st->analysis_mem for the next call.
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);

    RNN_COPY(buf, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(buf + FRAME_SIZE, in, FRAME_SIZE);
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);

    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    forward_transform(st, X, buf);
    compute_band_energy(Ex, X);
}
522
523
/*
 * Inverse-FFT spectrum y, window it, and overlap-add with the tail saved from
 * the previous call. The result is FRAME_SIZE samples, which are then blended
 * with st->history: out = out * mix + history * (1 - max(mix, 0)).
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    const float *dry_src = st->history;
    const float wet = s->mix;
    const float dry = 1.f - FFMAX(wet, 0.f);

    inverse_transform(st, buf, y);
    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);

    /* First half + saved tail -> output; second half becomes the new tail. */
    s->fdsp->vector_fmac_scalar(buf, st->synthesis_mem, 1.f, FRAME_SIZE);
    RNN_COPY(out, buf, FRAME_SIZE);
    RNN_COPY(st->synthesis_mem, buf + FRAME_SIZE, FRAME_SIZE);

    for (int n = 0; n < FRAME_SIZE; n++)
        out[n] = out[n] * wet + dry_src[n] * dry;
}
539
540
/*
 * Unrolled 4-lag correlation kernel: for k = 0..3 adds
 *   sum_{j < len} x[j] * y[j + k]
 * to sum[k]. Reads x[0 .. len-1] and y[0 .. len+2]; y[0..2] are read even
 * when len <= 0.
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    /* ya..yd form a sliding window over y that rotates by one slot per sample. */
    float ya = y[0], yb = y[1], yc = y[2], yd = 0;
    int j = 0;

    for (; j < len - 3; j += 4) {
        const float x0 = x[j], x1 = x[j + 1], x2 = x[j + 2], x3 = x[j + 3];

        yd = y[j + 3];
        sum[0] += x0 * ya;
        sum[1] += x0 * yb;
        sum[2] += x0 * yc;
        sum[3] += x0 * yd;

        ya = y[j + 4];
        sum[0] += x1 * yb;
        sum[1] += x1 * yc;
        sum[2] += x1 * yd;
        sum[3] += x1 * ya;

        yb = y[j + 5];
        sum[0] += x2 * yc;
        sum[1] += x2 * yd;
        sum[2] += x2 * ya;
        sum[3] += x2 * yb;

        yc = y[j + 6];
        sum[0] += x3 * yd;
        sum[1] += x3 * ya;
        sum[2] += x3 * yb;
        sum[3] += x3 * yc;
    }

    /* Up to three leftover samples, continuing the same rotation. */
    if (j < len) {
        const float t = x[j];

        yd = y[j + 3];
        sum[0] += t * ya;
        sum[1] += t * yb;
        sum[2] += t * yc;
        sum[3] += t * yd;
    }

    if (j + 1 < len) {
        const float t = x[j + 1];

        ya = y[j + 4];
        sum[0] += t * yb;
        sum[1] += t * yc;
        sum[2] += t * yd;
        sum[3] += t * ya;
    }

    if (j + 2 < len) {
        const float t = x[j + 2];

        yb = y[j + 5];
        sum[0] += t * yc;
        sum[1] += t * yd;
        sum[2] += t * ya;
        sum[3] += t * yb;
    }
}
608
609
/* Dot product of the first N elements of x and y, accumulated in index order. */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;
    int k = 0;

    while (k < N) {
        acc += x[k] * y[k];
        k++;
    }

    return acc;
}
619
620
/*
 * xcorr[k] = sum_{j < len} x[j] * y[j + k] for k = 0 .. max_pitch-1.
 * Lags are processed four at a time by xcorr_kernel(). Any leftover lags
 * (when max_pitch is not a multiple of 4) use celt_inner_prod().
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    while (lag < max_pitch - 3) {
        float acc[4] = { 0, 0, 0, 0 };

        xcorr_kernel(x, y + lag, acc, len);
        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = acc[k];
        lag += 4;
    }

    while (lag < max_pitch) {
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
        lag++;
    }
}
640
641
static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
642
                         float       *ac,  /* out: [0...lag-1] ac values */
643
                         const float *window,
644
                         int          overlap,
645
                         int          lag,
646
                         int          n)
647
{
648
    int fastN = n - lag;
649
    int shift;
650
    const float *xptr;
651
    float xx[PITCH_BUF_SIZE>>1];
652
653
    if (overlap == 0) {
654
        xptr = x;
655
    } else {
656
        for (int i = 0; i < n; i++)
657
            xx[i] = x[i];
658
        for (int i = 0; i < overlap; i++) {
659
            xx[i] = x[i] * window[i];
660
            xx[n-i-1] = x[n-i-1] * window[i];
661
        }
662
        xptr = xx;
663
    }
664
665
    shift = 0;
666
    celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
667
668
    for (int k = 0; k <= lag; k++) {
669
        float d = 0.f;
670
671
        for (int i = k + fastN; i < n; i++)
672
            d += xptr[i] * xptr[i-k];
673
        ac[k] += d;
674
    }
675
676
    return shift;
677
}
678
679
/*
 * Levinson-Durbin recursion: derive p LPC coefficients from autocorrelation
 * values ac[0..p]. lpc is zeroed first and left all-zero when ac[0] == 0.
 * The recursion stops early once the prediction error falls below
 * 0.001 * ac[0] (30 dB of prediction gain).
 */
static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
                const float *ac,   /* in:  [0...p] autocorrelation values  */
                          int p)
{
    float error = ac[0];

    memset(lpc, 0, p * sizeof(*lpc));
    if (ac[0] == 0)
        return;

    for (int i = 0; i < p; i++) {
        float rr = 0;
        float r;

        /* Reflection coefficient for this order. */
        for (int j = 0; j < i; j++)
            rr += (lpc[j] * ac[i - j]);
        rr += ac[i + 1];
        r = -rr / error;

        /* Update the coefficients symmetrically from both ends. */
        lpc[i] = r;
        for (int j = 0; j < (i + 1) >> 1; j++) {
            const float lo = lpc[j];
            const float hi = lpc[i - 1 - j];

            lpc[j]         = lo + (r * hi);
            lpc[i - 1 - j] = hi + (r * lo);
        }

        error = error - (r * r * error);
        if (error < .001f * ac[0])
            break;
    }
}
711
712
/*
 * 5-tap FIR: y[i] = x[i] + sum_{k=0}^{4} num[k] * x[i-1-k]. The history
 * x[i-1..i-5] is carried in mem[0..4] across calls. Each x[i] is read before
 * y[i] is written, so in-place use (x == y) is allowed.
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    const float c0 = num[0], c1 = num[1], c2 = num[2], c3 = num[3], c4 = num[4];
    float d0 = mem[0], d1 = mem[1], d2 = mem[2], d3 = mem[3], d4 = mem[4];

    for (int i = 0; i < N; i++) {
        const float in = x[i];
        float acc = in;

        acc += c0 * d0;
        acc += c1 * d1;
        acc += c2 * d2;
        acc += c3 * d3;
        acc += c4 * d4;

        d4 = d3;
        d3 = d2;
        d2 = d1;
        d1 = d0;
        d0 = in;

        y[i] = acc;
    }

    mem[0] = d0;
    mem[1] = d1;
    mem[2] = d2;
    mem[3] = d3;
    mem[4] = d4;
}
754
755
static void pitch_downsample(float *x[], float *x_lp,
756
                             int len, int C)
757
{
758
    float ac[5];
759
    float tmp=Q15ONE;
760
    float lpc[4], mem[5]={0,0,0,0,0};
761
    float lpc2[5];
762
    float c1 = .8f;
763
764
    for (int i = 1; i < len >> 1; i++)
765
        x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
766
    x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
767
    if (C==2) {
768
        for (int i = 1; i < len >> 1; i++)
769
            x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
770
        x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
771
    }
772
773
    celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
774
775
    /* Noise floor -40 dB */
776
    ac[0] *= 1.0001f;
777
    /* Lag windowing */
778
    for (int i = 1; i <= 4; i++) {
779
        /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
780
        ac[i] -= ac[i]*(.008f*i)*(.008f*i);
781
    }
782
783
    celt_lpc(lpc, ac, 4);
784
    for (int i = 0; i < 4; i++) {
785
        tmp = .9f * tmp;
786
        lpc[i] = (lpc[i] * tmp);
787
    }
788
    /* Add a zero */
789
    lpc2[0] = lpc[0] + .8f;
790
    lpc2[1] = lpc[1] + (c1 * lpc[0]);
791
    lpc2[2] = lpc[2] + (c1 * lpc[1]);
792
    lpc2[3] = lpc[3] + (c1 * lpc[2]);
793
    lpc2[4] = (c1 * lpc[3]);
794
    celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
795
}
796
797
/* Two dot products sharing x: *xy1 = <x, y01>, *xy2 = <x, y02> over N elements. */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;
    int i = 0;

    while (i < N) {
        const float xi = x[i];

        acc1 += xi * y01[i];
        acc2 += xi * y02[i];
        i++;
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
810
811
static float compute_pitch_gain(float xy, float xx, float yy)
812
{
813
    return xy / sqrtf(1.f + xx * yy);
814
}
815
816
/* Multiplier for the secondary lag checked by remove_doubling() for k = 3..15
 * (entries 0..2 are not read there). */
static const uint8_t second_check[16] = {
    0, 0, 3, 2, 3, 2, 5, 2,
    3, 2, 3, 2, 5, 2, 3, 2,
};
817
static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
818
                             int *T0_, int prev_period, float prev_gain)
819
{
820
    int k, i, T, T0;
821
    float g, g0;
822
    float pg;
823
    float xy,xx,yy,xy2;
824
    float xcorr[3];
825
    float best_xy, best_yy;
826
    int offset;
827
    int minperiod0;
828
    float yy_lookup[PITCH_MAX_PERIOD+1];
829
830
    minperiod0 = minperiod;
831
    maxperiod /= 2;
832
    minperiod /= 2;
833
    *T0_ /= 2;
834
    prev_period /= 2;
835
    N /= 2;
836
    x += maxperiod;
837
    if (*T0_>=maxperiod)
838
        *T0_=maxperiod-1;
839
840
    T = T0 = *T0_;
841
    dual_inner_prod(x, x, x-T0, N, &xx, &xy);
842
    yy_lookup[0] = xx;
843
    yy=xx;
844
    for (i = 1; i <= maxperiod; i++) {
845
        yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
846
        yy_lookup[i] = FFMAX(0, yy);
847
    }
848
    yy = yy_lookup[T0];
849
    best_xy = xy;
850
    best_yy = yy;
851
    g = g0 = compute_pitch_gain(xy, xx, yy);
852
    /* Look for any pitch at T/k */
853
    for (k = 2; k <= 15; k++) {
854
        int T1, T1b;
855
        float g1;
856
        float cont=0;
857
        float thresh;
858
        T1 = (2*T0+k)/(2*k);
859
        if (T1 < minperiod)
860
            break;
861
        /* Look for another strong correlation at T1b */
862
        if (k==2)
863
        {
864
            if (T1+T0>maxperiod)
865
                T1b = T0;
866
            else
867
                T1b = T0+T1;
868
        } else
869
        {
870
            T1b = (2*second_check[k]*T0+k)/(2*k);
871
        }
872
        dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
873
        xy = .5f * (xy + xy2);
874
        yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
875
        g1 = compute_pitch_gain(xy, xx, yy);
876
        if (FFABS(T1-prev_period)<=1)
877
            cont = prev_gain;
878
        else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
879
            cont = prev_gain * .5f;
880
        else
881
            cont = 0;
882
        thresh = FFMAX(.3f, (.7f * g0) - cont);
883
        /* Bias against very high pitch (very short period) to avoid false-positives
884
           due to short-term correlation */
885
        if (T1<3*minperiod)
886
            thresh = FFMAX(.4f, (.85f * g0) - cont);
887
        else if (T1<2*minperiod)
888
            thresh = FFMAX(.5f, (.9f * g0) - cont);
889
        if (g1 > thresh)
890
        {
891
            best_xy = xy;
892
            best_yy = yy;
893
            T = T1;
894
            g = g1;
895
        }
896
    }
897
    best_xy = FFMAX(0, best_xy);
898
    if (best_yy <= best_xy)
899
        pg = Q15ONE;
900
    else
901
        pg = best_xy/(best_yy + 1);
902
903
    for (k = 0; k < 3; k++)
904
        xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
905
    if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
906
        offset = 1;
907
    else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
908
        offset = -1;
909
    else
910
        offset = 0;
911
    if (pg > g)
912
        pg = g;
913
    *T0_ = 2*T+offset;
914
915
    if (*T0_<minperiod0)
916
        *T0_=minperiod0;
917
    return pg;
918
}
919
920
/*
 * Find the two lags i in [0, max_pitch) with the largest xcorr[i]^2 / Syy(i),
 * considering only positive correlations. Syy(i) is 1 plus the energy of
 * y[i .. i+len-1], tracked as a sliding sum and floored at 1.
 * best_pitch[0] receives the best lag and best_pitch[1] the runner-up
 * (defaults 0 and 1). y must hold len + max_pitch samples.
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float best_num[2] = { -1, -1 };
    float best_den[2] = { 0, 0 };
    float Syy = 1.f;

    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        Syy += y[j] * y[j];

    for (int i = 0; i < max_pitch; i++) {
        if (xcorr[i] > 0) {
            /* Pre-scale so squaring neither underflows nor overflows. */
            const float scaled = xcorr[i] * 1e-12f;
            const float num = scaled * scaled;

            /* num / Syy vs best_num / best_den, compared without dividing. */
            if ((num * best_den[1]) > (best_num[1] * Syy)) {
                if ((num * best_den[0]) > (best_num[0] * Syy)) {
                    best_num[1]   = best_num[0];
                    best_den[1]   = best_den[0];
                    best_pitch[1] = best_pitch[0];
                    best_num[0]   = num;
                    best_den[0]   = Syy;
                    best_pitch[0] = i;
                } else {
                    best_num[1]   = num;
                    best_den[1]   = Syy;
                    best_pitch[1] = i;
                }
            }
        }

        /* Slide the energy window one sample forward. */
        Syy += y[i + len] * y[i + len] - y[i] * y[i];
        if (Syy < 1.f)
            Syy = 1.f;
    }
}
966
967
static void pitch_search(const float *x_lp, float *y,
968
                         int len, int max_pitch, int *pitch)
969
{
970
    int lag;
971
    int best_pitch[2]={0,0};
972
    int offset;
973
974
    float x_lp4[WINDOW_SIZE];
975
    float y_lp4[WINDOW_SIZE];
976
    float xcorr[WINDOW_SIZE];
977
978
    lag = len+max_pitch;
979
980
    /* Downsample by 2 again */
981
    for (int j = 0; j < len >> 2; j++)
982
        x_lp4[j] = x_lp[2*j];
983
    for (int j = 0; j < lag >> 2; j++)
984
        y_lp4[j] = y[2*j];
985
986
    /* Coarse search with 4x decimation */
987
988
    celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
989
990
    find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
991
992
    /* Finer search with 2x decimation */
993
    for (int i = 0; i < max_pitch >> 1; i++) {
994
        float sum;
995
        xcorr[i] = 0;
996
        if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
997
            continue;
998
        sum = celt_inner_prod(x_lp, y+i, len>>1);
999
        xcorr[i] = FFMAX(-1, sum);
1000
    }
1001
1002
    find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
1003
1004
    /* Refine by pseudo-interpolation */
1005
    if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
1006
        float a, b, c;
1007
1008
        a = xcorr[best_pitch[0] - 1];
1009
        b = xcorr[best_pitch[0]];
1010
        c = xcorr[best_pitch[0] + 1];
1011
        if (c - a > .7f * (b - a))
1012
            offset = 1;
1013
        else if (a - c > .7f * (b-c))
1014
            offset = -1;
1015
        else
1016
            offset = 0;
1017
    } else {
1018
        offset = 0;
1019
    }
1020
1021
    *pitch = 2 * best_pitch[0] - offset;
1022
}
1023
1024
/**
 * Scaled DCT of NB_BANDS values using the precomputed s->dct_table rows:
 * out[i] = sqrt(2 / NB_BANDS) * sum_k in[k] * dct_table[i][k].
 *
 * scalarproduct_float() is run over FFALIGN(NB_BANDS, 4) elements, but
 * callers pass NB_BANDS-sized arrays. The input is therefore first copied
 * into an aligned scratch buffer whose padding is zeroed. This avoids
 * reading past the end of the caller's array, and keeps the padded terms
 * exactly 0 (the padded table columns are never written by init()).
 *
 * @param s   filter context (provides fdsp and dct_table)
 * @param out NB_BANDS output coefficients
 * @param in  NB_BANDS input values
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    LOCAL_ALIGNED_32(float, padded, [FFALIGN(NB_BANDS, 4)]);
    const float scale = sqrtf(2.f / NB_BANDS);

    memcpy(padded, in, NB_BANDS * sizeof(*padded));
    memset(padded + NB_BANDS, 0, (FFALIGN(NB_BANDS, 4) - NB_BANDS) * sizeof(*padded));

    for (int i = 0; i < NB_BANDS; i++) {
        const float sum = s->fdsp->scalarproduct_float(padded, s->dct_table[i],
                                                       FFALIGN(NB_BANDS, 4));

        out[i] = sum * scale;
    }
}
1033
1034
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1035
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
1036
{
1037
    float E = 0;
1038
    float *ceps_0, *ceps_1, *ceps_2;
1039
    float spec_variability = 0;
1040
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1041
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1042
    float pitch_buf[PITCH_BUF_SIZE>>1];
1043
    int pitch_index;
1044
    float gain;
1045
    float *(pre[1]);
1046
    float tmp[NB_BANDS];
1047
    float follow, logMax;
1048
1049
    frame_analysis(s, st, X, Ex, in);
1050
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1051
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1052
    pre[0] = &st->pitch_buf[0];
1053
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1054
    pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1055
            PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1056
    pitch_index = PITCH_MAX_PERIOD-pitch_index;
1057
1058
    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1059
            PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1060
    st->last_period = pitch_index;
1061
    st->last_gain = gain;
1062
1063
    for (int i = 0; i < WINDOW_SIZE; i++)
1064
        p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1065
1066
    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1067
    forward_transform(st, P, p);
1068
    compute_band_energy(Ep, P);
1069
    compute_band_corr(Exp, X, P);
1070
1071
    for (int i = 0; i < NB_BANDS; i++)
1072
        Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1073
1074
    dct(s, tmp, Exp);
1075
1076
    for (int i = 0; i < NB_DELTA_CEPS; i++)
1077
        features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1078
1079
    features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1080
    features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1081
    features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1082
    logMax = -2;
1083
    follow = -2;
1084
1085
    for (int i = 0; i < NB_BANDS; i++) {
1086
        Ly[i] = log10f(1e-2f + Ex[i]);
1087
        Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1088
        logMax = FFMAX(logMax, Ly[i]);
1089
        follow = FFMAX(follow-1.5, Ly[i]);
1090
        E += Ex[i];
1091
    }
1092
1093
    if (E < 0.04f) {
1094
        /* If there's no audio, avoid messing up the state. */
1095
        RNN_CLEAR(features, NB_FEATURES);
1096
        return 1;
1097
    }
1098
1099
    dct(s, features, Ly);
1100
    features[0] -= 12;
1101
    features[1] -= 4;
1102
    ceps_0 = st->cepstral_mem[st->memid];
1103
    ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1104
    ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1105
1106
    for (int i = 0; i < NB_BANDS; i++)
1107
        ceps_0[i] = features[i];
1108
1109
    st->memid++;
1110
    for (int i = 0; i < NB_DELTA_CEPS; i++) {
1111
        features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1112
        features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1113
        features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1114
    }
1115
    /* Spectral variability features. */
1116
    if (st->memid == CEPS_MEM)
1117
        st->memid = 0;
1118
1119
    for (int i = 0; i < CEPS_MEM; i++) {
1120
        float mindist = 1e15f;
1121
        for (int j = 0; j < CEPS_MEM; j++) {
1122
            float dist = 0.f;
1123
            for (int k = 0; k < NB_BANDS; k++) {
1124
                float tmp;
1125
1126
                tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1127
                dist += tmp*tmp;
1128
            }
1129
1130
            if (j != i)
1131
                mindist = FFMIN(mindist, dist);
1132
        }
1133
1134
        spec_variability += mindist;
1135
    }
1136
1137
    features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1138
1139
    return 0;
1140
}
1141
1142
static void interp_band_gain(float *g, const float *bandE)
1143
{
1144
    memset(g, 0, sizeof(*g) * FREQ_SIZE);
1145
1146
    for (int i = 0; i < NB_BANDS - 1; i++) {
1147
        const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1148
1149
        for (int j = 0; j < band_size; j++) {
1150
            float frac = (float)j / band_size;
1151
1152
            g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1153
        }
1154
    }
1155
}
1156
1157
/**
 * Mix the pitch-delayed spectrum P into X, then restore X's band energies.
 *
 * Per band b, the mixing weight is derived from the pitch correlation Exp[b]
 * and the gain g[b]: 1 when Exp[b] > g[b], otherwise a clipped ratio of
 * their squares. The weight is then scaled by sqrt(Ex[b] / Ep[b]). The
 * weights are interpolated to bins, X += weight * P is applied, and X is
 * rescaled bin-wise by the interpolated sqrt(Ex / newE).
 */
static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
                         const float *Exp, const float *g)
{
    float r[NB_BANDS], norm[NB_BANDS], newE[NB_BANDS];
    float rf[FREQ_SIZE] = {0};
    float normf[FREQ_SIZE] = {0};

    for (int b = 0; b < NB_BANDS; b++) {
        float ratio;

        if (Exp[b] > g[b]) {
            ratio = 1;
        } else {
            ratio = SQUARE(Exp[b])*(1-SQUARE(g[b])) /
                    (.001 + SQUARE(g[b])*(1-SQUARE(Exp[b])));
        }

        r[b]  = sqrtf(av_clipf(ratio, 0, 1));
        r[b] *= sqrtf(Ex[b]/(1e-8+Ep[b]));
    }

    /* Add the weighted pitch spectrum. */
    interp_band_gain(rf, r);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re += rf[k] * P[k].re;
        X[k].im += rf[k] * P[k].im;
    }

    /* Rescale so each band moves back towards its original energy. */
    compute_band_energy(newE, X);
    for (int b = 0; b < NB_BANDS; b++)
        norm[b] = sqrtf(Ex[b] / (1e-8+newE[b]));
    interp_band_gain(normf, norm);

    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re *= normf[k];
        X[k].im *= normf[k];
    }
}
1187
1188
/*
 * tanh(0.04 * i) for i = 0..200, i.e. x from 0 to 8 in steps of 0.04,
 * rounded to six decimals. Used by tansig_approx().
 * Each row holds five consecutive entries; the comment gives the first index.
 */
static const float tansig_table[201] = {
    /*   0 */ 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    /*   5 */ 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    /*  10 */ 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    /*  15 */ 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    /*  20 */ 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    /*  25 */ 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    /*  30 */ 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    /*  35 */ 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    /*  40 */ 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    /*  45 */ 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    /*  50 */ 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    /*  55 */ 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    /*  60 */ 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    /*  65 */ 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    /*  70 */ 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    /*  75 */ 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    /*  80 */ 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    /*  85 */ 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    /*  90 */ 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    /*  95 */ 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    /* 100 */ 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    /* 105 */ 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    /* 110 */ 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    /* 115 */ 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    /* 120 */ 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    /* 125 */ 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    /* 130 */ 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    /* 135 */ 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    /* 140 */ 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    /* 145 */ 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    /* 150 */ 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    /* 155 */ 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    /* 160 */ 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    /* 165 */ 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    /* 170 */ 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    /* 175 */ 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    /* 180 */ 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    /* 185 */ 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    /* 190 */ 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    /* 195 */ 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    /* 200 */ 1.000000f,
};
1231
1232
static inline float tansig_approx(float x)
1233
{
1234
    float y, dy;
1235
    float sign=1;
1236
    int i;
1237
1238
    /* Tests are reversed to catch NaNs */
1239
    if (!(x<8))
1240
        return 1;
1241
    if (!(x>-8))
1242
        return -1;
1243
    /* Another check in case of -ffast-math */
1244
1245
    if (isnan(x))
1246
       return 0;
1247
1248
    if (x < 0) {
1249
       x=-x;
1250
       sign=-1;
1251
    }
1252
    i = (int)floor(.5f+25*x);
1253
    x -= .04f*i;
1254
    y = tansig_table[i];
1255
    dy = 1-y*y;
1256
    y = y + x*dy*(1 - y*x);
1257
    return sign*y;
1258
}
1259
1260
/** Logistic sigmoid via the identity sigmoid(x) = 0.5 + 0.5 * tanh(x / 2). */
static inline float sigmoid_approx(float x)
{
    const float t = tansig_approx(.5f * x);

    return .5f + .5f * t;
}
1264
1265
static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1266
{
1267
    const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1268
1269
    for (int i = 0; i < N; i++) {
1270
        /* Compute update gate. */
1271
        float sum = layer->bias[i];
1272
1273
        for (int j = 0; j < M; j++)
1274
            sum += layer->input_weights[j * stride + i] * input[j];
1275
1276
        output[i] = WEIGHTS_SCALE * sum;
1277
    }
1278
1279
    if (layer->activation == ACTIVATION_SIGMOID) {
1280
        for (int i = 0; i < N; i++)
1281
            output[i] = sigmoid_approx(output[i]);
1282
    } else if (layer->activation == ACTIVATION_TANH) {
1283
        for (int i = 0; i < N; i++)
1284
            output[i] = tansig_approx(output[i]);
1285
    } else if (layer->activation == ACTIVATION_RELU) {
1286
        for (int i = 0; i < N; i++)
1287
            output[i] = FFMAX(0, output[i]);
1288
    } else {
1289
        av_assert0(0);
1290
    }
1291
}
1292
1293
/**
 * One GRU step, updating state (nb_neurons values) in place.
 *
 * Per neuron, the weights hold three row segments (update, reset, candidate):
 * - input weights: segments of FFALIGN(nb_inputs, 4) floats each, so a
 *   neuron's stride is 3 * FFALIGN(nb_inputs, 4);
 * - recurrent weights: segments of FFALIGN(nb_neurons, 4) floats each, so a
 *   neuron's stride is 3 * FFALIGN(nb_neurons, 4).
 *
 * scalarproduct_float() reads the padded lengths from input and state.
 * The new state is h = z * state + (1 - z) * act(candidate).
 */
static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
{
    LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
    const int nb_out  = gru->nb_neurons;
    const int in_len  = FFALIGN(gru->nb_inputs, 4);
    const int rec_len = FFALIGN(nb_out, 4);

    /* Update (z) and reset (r) gates; both only read input and the old state. */
    for (int n = 0; n < nb_out; n++) {
        const float *w_in  = gru->input_weights     + n * 3 * in_len;
        const float *w_rec = gru->recurrent_weights + n * 3 * rec_len;
        float zsum = gru->bias[n];
        float rsum = gru->bias[nb_out + n];

        zsum += s->fdsp->scalarproduct_float(w_in, input, in_len);
        zsum += s->fdsp->scalarproduct_float(w_rec, state, rec_len);
        rsum += s->fdsp->scalarproduct_float(w_in + in_len, input, in_len);
        rsum += s->fdsp->scalarproduct_float(w_rec + rec_len, state, rec_len);

        z[n] = sigmoid_approx(WEIGHTS_SCALE * zsum);
        r[n] = sigmoid_approx(WEIGHTS_SCALE * rsum);
    }

    /* Candidate activation, with the recurrent term gated by r. */
    for (int n = 0; n < nb_out; n++) {
        const float *w_in  = gru->input_weights     + n * 3 * in_len  + 2 * in_len;
        const float *w_rec = gru->recurrent_weights + n * 3 * rec_len + 2 * rec_len;
        float acc = gru->bias[2 * nb_out + n];

        acc += s->fdsp->scalarproduct_float(w_in, input, in_len);
        for (int j = 0; j < nb_out; j++)
            acc += w_rec[j] * state[j] * r[j];

        switch (gru->activation) {
        case ACTIVATION_SIGMOID:
            acc = sigmoid_approx(WEIGHTS_SCALE * acc);
            break;
        case ACTIVATION_TANH:
            acc = tansig_approx(WEIGHTS_SCALE * acc);
            break;
        case ACTIVATION_RELU:
            acc = FFMAX(0, WEIGHTS_SCALE * acc);
            break;
        default:
            av_assert0(0);
            break;
        }

        h[n] = z[n] * state[n] + (1.f - z[n]) * acc;
    }

    RNN_COPY(state, h, nb_out);
}
1343
1344
#define INPUT_SIZE 42
1345
1346
static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1347
{
1348
    LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1349
    LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1350
    LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1351
1352
    compute_dense(rnn->model->input_dense, dense_out, input);
1353
    compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1354
    compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1355
1356
    memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1357
    memcpy(noise_input + rnn->model->input_dense_size,
1358
           rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1359
    memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1360
           input, INPUT_SIZE * sizeof(float));
1361
1362
    compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1363
1364
    memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1365
    memcpy(denoise_input + rnn->model->vad_gru_size,
1366
           rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1367
    memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1368
           input, INPUT_SIZE * sizeof(float));
1369
1370
    compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1371
    compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1372
}
1373
1374
/**
 * Denoise one FRAME_SIZE-sample frame of a single channel.
 *
 * The input is pre-filtered with the a_hp/b_hp biquad and analysed. Unless
 * the frame is silent or processing is disabled, the RNN band gains are then
 * applied (with pitch filtering and gain smoothing). The frame is
 * resynthesised into out, and the raw input is saved in st->history.
 *
 * @return the RNN's VAD output, or 0 when the RNN was not run
 */
static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
                             int disabled)
{
    static const float a_hp[2] = {-1.99599, 0.99600};
    static const float b_hp[2] = {-2, 1};
    AVComplexFloat X[FREQ_SIZE];
    AVComplexFloat P[WINDOW_SIZE];
    LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
    float x[FRAME_SIZE];
    float Ex[NB_BANDS], Ep[NB_BANDS];
    float features[NB_FEATURES];
    float g[NB_BANDS];
    float gf[FREQ_SIZE];
    float vad_prob = 0;
    int silence;

    biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
    silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);

    if (!(silence || disabled)) {
        compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
        pitch_filter(X, P, Ex, Ep, Exp, g);

        /* Limit how fast each band gain may drop: at least 0.6x the last one. */
        for (int b = 0; b < NB_BANDS; b++) {
            g[b] = FFMAX(g[b], .6f * st->lastg[b]);
            st->lastg[b] = g[b];
        }

        interp_band_gain(gf, g);
        for (int k = 0; k < FREQ_SIZE; k++) {
            X[k].re *= gf[k];
            X[k].im *= gf[k];
        }
    }

    frame_synthesis(s, st, out, X);
    memcpy(st->history, in, FRAME_SIZE * sizeof(*st->history));

    return vad_prob;
}
1417
1418
typedef struct ThreadData {
1419
    AVFrame *in, *out;
1420
} ThreadData;
1421
1422
/**
 * Slice-thread worker: denoise the channels [first, last) assigned to job
 * jobnr out of nb_jobs, using the per-channel state in s->st.
 *
 * @return always 0
 */
static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
{
    AudioRNNContext *s = ctx->priv;
    const ThreadData *td = arg;
    const int nb_channels = td->out->channels;
    const int first = (nb_channels * jobnr) / nb_jobs;
    const int last  = (nb_channels * (jobnr + 1)) / nb_jobs;

    for (int ch = first; ch < last; ch++) {
        const float *src = (const float *)td->in->extended_data[ch];
        float *dst = (float *)td->out->extended_data[ch];

        rnnoise_channel(s, &s->st[ch], dst, src, ctx->is_disabled);
    }

    return 0;
}
1440
1441
/**
 * Process one FRAME_SIZE-sample input frame.
 *
 * Allocates a new output frame and denoises all channels via slice-threaded
 * jobs. Always frees the input, then passes the output downstream.
 */
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *ctx = inlink->dst;
    AVFilterLink *outlink = ctx->outputs[0];
    ThreadData td;
    AVFrame *out = ff_get_audio_buffer(outlink, FRAME_SIZE);

    if (!out) {
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }
    out->pts = in->pts;

    td.in  = in;
    td.out = out;
    ctx->internal->execute(ctx, rnnoise_channels, &td, NULL,
                           FFMIN(outlink->channels, ff_filter_get_nb_threads(ctx)));

    av_frame_free(&in);
    return ff_filter_frame(outlink, out);
}
1462
1463
/**
 * Filter scheduling callback.
 *
 * Consumes input in chunks of exactly FRAME_SIZE samples and processes
 * each chunk. Otherwise it forwards EOF/status and output demand between
 * the links.
 */
static int activate(AVFilterContext *ctx)
{
    AVFilterLink *inlink = ctx->inputs[0];
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *in = NULL;
    int ret;

    FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);

    ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
    if (ret > 0)
        return filter_frame(inlink, in);
    if (ret < 0)
        return ret;

    /* Not enough samples queued: propagate status or request more input. */
    FF_FILTER_FORWARD_STATUS(inlink, outlink);
    FF_FILTER_FORWARD_WANTED(outlink, inlink);

    return FFERROR_NOT_READY;
}
1484
1485
/**
 * Load the model file named by the "model" option into *model.
 *
 * @return 0 on success, with *model set. Otherwise a negative AVERROR:
 *         EINVAL if no model is configured or the file cannot be opened,
 *         the loader's error code if parsing fails, or INVALIDDATA if the
 *         loader produced no model.
 */
static int open_model(AVFilterContext *ctx, RNNModel **model)
{
    AudioRNNContext *s = ctx->priv;
    int ret;
    FILE *f;

    if (!s->model_name) {
        av_log(ctx, AV_LOG_ERROR, "No model file specified\n");
        return AVERROR(EINVAL);
    }
    f = av_fopen_utf8(s->model_name, "r");
    if (!f) {
        av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    ret = rnnoise_model_from_file(f, model);
    fclose(f);
    if (ret < 0)
        return ret;
    /* Callers treat success as "a model is available"; never return 0
     * without one. */
    if (!*model)
        return AVERROR_INVALIDDATA;

    return 0;
}
1506
1507
/**
 * Filter initialisation.
 *
 * Allocates the float DSP context and loads the model into slot 0. It then
 * precomputes the analysis/synthesis window and the cosine table used by
 * dct().
 */
static av_cold int init(AVFilterContext *ctx)
{
    AudioRNNContext *s = ctx->priv;
    int ret;

    s->fdsp = avpriv_float_dsp_alloc(0);
    if (!s->fdsp)
        return AVERROR(ENOMEM);

    ret = open_model(ctx, &s->model[0]);
    if (ret < 0)
        return ret;

    /* Symmetric window: w[n] = sin(pi/2 * sin^2(pi/2 * (n + 0.5) / FRAME_SIZE)),
     * mirrored into the second half. */
    for (int n = 0; n < FRAME_SIZE; n++) {
        const double t = sin(.5*M_PI*(n+.5)/FRAME_SIZE);

        s->window[n] = sin(.5*M_PI*t*t);
        s->window[WINDOW_SIZE - 1 - n] = s->window[n];
    }

    /* dct_table[row][col] = cos((col + 0.5) * row * pi / NB_BANDS),
     * with row 0 additionally scaled by sqrt(0.5). */
    for (int row = 0; row < NB_BANDS; row++) {
        for (int col = 0; col < NB_BANDS; col++) {
            float v = cosf((col + .5f) * row * M_PI / NB_BANDS);

            if (row == 0)
                v *= sqrtf(.5);
            s->dct_table[row][col] = v;
        }
    }

    return 0;
}
1535
1536
/**
 * Release model slot n together with every channel's GRU state buffers for
 * that slot. All released pointers are reset to NULL.
 */
static void free_model(AVFilterContext *ctx, int n)
{
    AudioRNNContext *s = ctx->priv;

    rnnoise_model_free(s->model[n]);
    s->model[n] = NULL;

    if (!s->st)
        return;

    for (int ch = 0; ch < s->channels; ch++) {
        RNNState *rnn = &s->st[ch].rnn[n];

        av_freep(&rnn->vad_gru_state);
        av_freep(&rnn->noise_gru_state);
        av_freep(&rnn->denoise_gru_state);
    }
}
1549
1550
/**
 * Runtime command handler.
 *
 * The option is first updated through ff_filter_process_command(). For a
 * model change ("model"/"m"), the new file is loaded into slot 1 and swapped
 * into slot 0, and per-channel state is rebuilt by config_input(). On
 * success the old model (now in slot 1) is freed. On failure the previous
 * model and states are restored, and everything staged for the new model is
 * released so it does not leak.
 */
static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
                           char *res, int res_len, int flags)
{
    AudioRNNContext *s = ctx->priv;
    int ret;

    ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
    if (ret < 0)
        return ret;

    /* Other runtime options (e.g. "mix") only update their field; reloading
     * the model for them would needlessly reset all GRU state. */
    if (strcmp(cmd, "model") && strcmp(cmd, "m"))
        return 0;

    ret = open_model(ctx, &s->model[1]);
    if (ret < 0)
        return ret;

    FFSWAP(RNNModel *, s->model[0], s->model[1]);
    for (int ch = 0; ch < s->channels; ch++)
        FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);

    ret = config_input(ctx->inputs[0]);
    if (ret < 0) {
        for (int ch = 0; ch < s->channels; ch++)
            FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
        FFSWAP(RNNModel *, s->model[0], s->model[1]);
        /* Slot 1 now holds the rejected model and any states config_input()
         * allocated for it before failing. */
        free_model(ctx, 1);
        return ret;
    }

    free_model(ctx, 1);
    return 0;
}
1579
1580
/**
 * Filter teardown.
 *
 * Frees the DSP context, both model slots with their per-channel GRU states,
 * each channel's forward/inverse transform contexts, and the channel state
 * array.
 */
static av_cold void uninit(AVFilterContext *ctx)
{
    AudioRNNContext *s = ctx->priv;

    av_freep(&s->fdsp);
    free_model(ctx, 0);
    /* Slot 1 is only a staging area for process_command(). Release it too so
     * that nothing left there can outlive the filter. free_model() is already
     * called on an empty slot 0 when init() fails, so an empty slot is fine. */
    free_model(ctx, 1);
    for (int ch = 0; ch < s->channels && s->st; ch++) {
        av_tx_uninit(&s->st[ch].tx);
        av_tx_uninit(&s->st[ch].txi);
    }
    av_freep(&s->st);
}
1592
1593
static const AVFilterPad inputs[] = {
1594
    {
1595
        .name         = "default",
1596
        .type         = AVMEDIA_TYPE_AUDIO,
1597
        .config_props = config_input,
1598
    },
1599
    { NULL }
1600
};
1601
1602
static const AVFilterPad outputs[] = {
1603
    {
1604
        .name          = "default",
1605
        .type          = AVMEDIA_TYPE_AUDIO,
1606
    },
1607
    { NULL }
1608
};
1609
1610
#define OFFSET(x) offsetof(AudioRNNContext, x)
1611
#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
1612
1613
static const AVOption arnndn_options[] = {
1614
    { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1615
    { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1616
    { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1617
    { NULL }
1618
};
1619
1620
AVFILTER_DEFINE_CLASS(arnndn);
1621
1622
AVFilter ff_af_arnndn = {
1623
    .name          = "arnndn",
1624
    .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1625
    .query_formats = query_formats,
1626
    .priv_size     = sizeof(AudioRNNContext),
1627
    .priv_class    = &arnndn_class,
1628
    .activate      = activate,
1629
    .init          = init,
1630
    .uninit        = uninit,
1631
    .inputs        = inputs,
1632
    .outputs       = outputs,
1633
    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1634
                     AVFILTER_FLAG_SLICE_THREADS,
1635
    .process_command = process_command,
1636
};