GCC Code Coverage Report
Directory: ../../../ffmpeg/ Exec Total Coverage
File: src/libavfilter/af_arnndn.c Lines: 0 711 0.0 %
Date: 2021-01-22 05:18:52 Branches: 0 596 0.0 %

Line Branch Exec Source
1
/*
2
 * Copyright (c) 2018 Gregor Richards
3
 * Copyright (c) 2017 Mozilla
4
 * Copyright (c) 2005-2009 Xiph.Org Foundation
5
 * Copyright (c) 2007-2008 CSIRO
6
 * Copyright (c) 2008-2011 Octasic Inc.
7
 * Copyright (c) Jean-Marc Valin
8
 * Copyright (c) 2019 Paul B Mahol
9
 *
10
 * Redistribution and use in source and binary forms, with or without
11
 * modification, are permitted provided that the following conditions
12
 * are met:
13
 *
14
 * - Redistributions of source code must retain the above copyright
15
 *   notice, this list of conditions and the following disclaimer.
16
 *
17
 * - Redistributions in binary form must reproduce the above copyright
18
 *   notice, this list of conditions and the following disclaimer in the
19
 *   documentation and/or other materials provided with the distribution.
20
 *
21
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
 * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
 */
33
34
#include <float.h>
35
36
#include "libavutil/avassert.h"
37
#include "libavutil/avstring.h"
38
#include "libavutil/float_dsp.h"
39
#include "libavutil/mem_internal.h"
40
#include "libavutil/opt.h"
41
#include "libavutil/tx.h"
42
#include "avfilter.h"
43
#include "audio.h"
44
#include "filters.h"
45
#include "formats.h"
46
47
/* Frame geometry: FRAME_SIZE = 120 << 2 = 480 samples are consumed per frame. */
#define FRAME_SIZE_SHIFT 2
48
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
49
/* The analysis/synthesis window spans two frames (960 samples). */
#define WINDOW_SIZE (2*FRAME_SIZE)
50
/* Number of spectrum bins copied out of the forward transform (481). */
#define FREQ_SIZE (FRAME_SIZE + 1)
51
52
/* Pitch period search limits, in samples of the full-rate signal. */
#define PITCH_MIN_PERIOD 60
53
#define PITCH_MAX_PERIOD 768
54
#define PITCH_FRAME_SIZE 960
55
/* Length of the per-channel pitch history buffers (768 + 960 = 1728). */
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56
57
/* Note: the argument is evaluated twice; do not pass expressions with side effects. */
#define SQUARE(x) ((x)*(x))
58
59
/* Number of frequency bands; matches the number of entries in eband5ms[]. */
#define NB_BANDS 22
60
61
/* Number of rows in DenoiseState.cepstral_mem. */
#define CEPS_MEM 8
62
#define NB_DELTA_CEPS 6
63
64
/* Feature vector length: 22 + 3 * 6 + 2 = 42. */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
65
66
#define WEIGHTS_SCALE (1.f/256)
67
68
/* NOTE(review): presumably the upper bound on neurons per layer; the model
 * loader hard-codes the same value (128) when validating sizes. */
#define MAX_NEURONS 128
69
70
/* Activation identifiers stored in DenseLayer/GRULayer.activation. */
#define ACTIVATION_TANH    0
71
#define ACTIVATION_SIGMOID 1
72
#define ACTIVATION_RELU    2
73
74
/* Unity gain constant (float counterpart of the fixed-point name). */
#define Q15ONE 1.0f
75
76
/**
 * Fully connected layer description.
 *
 * The model loader fills bias with nb_neurons values and input_weights
 * with nb_inputs * nb_neurons values; both arrays are heap allocated and
 * released by rnnoise_model_free().
 */
typedef struct DenseLayer {
77
    const float *bias;
78
    const float *input_weights;
79
    int nb_inputs;
80
    int nb_neurons;
81
    int activation;     /* one of the ACTIVATION_* identifiers */
82
} DenseLayer;
83
84
/**
 * Gated recurrent unit layer description.
 *
 * The model loader stores three sets of weights per connection (the
 * innermost dimension is padded to a multiple of 4 floats) and
 * 3 * nb_neurons bias values. All arrays are heap allocated and released
 * by rnnoise_model_free().
 */
typedef struct GRULayer {
85
    const float *bias;
86
    const float *input_weights;
87
    const float *recurrent_weights;
88
    int nb_inputs;
89
    int nb_neurons;
90
    int activation;     /* one of the ACTIVATION_* identifiers */
91
} GRULayer;
92
93
/**
 * Complete network: six layers, each paired with a *_size field that the
 * model loader sets to the layer's nb_neurons.
 */
typedef struct RNNModel {
94
    int input_dense_size;
95
    const DenseLayer *input_dense;
96
97
    int vad_gru_size;
98
    const GRULayer *vad_gru;
99
100
    int noise_gru_size;
101
    const GRULayer *noise_gru;
102
103
    int denoise_gru_size;
104
    const GRULayer *denoise_gru;
105
106
    int denoise_output_size;
107
    const DenseLayer *denoise_output;
108
109
    /* The loader rejects models whose vad_output has other than 1 neuron. */
    int vad_output_size;
110
    const DenseLayer *vad_output;
111
} RNNModel;
112
113
/**
 * Per-channel recurrent state. config_input() allocates each state vector
 * zero-initialized, with its length rounded up to a multiple of 16 floats.
 */
typedef struct RNNState {
114
    float *vad_gru_state;
115
    float *noise_gru_state;
116
    float *denoise_gru_state;
117
    RNNModel *model;    /* points at the filter-wide model (AudioRNNContext.model) */
118
} RNNState;
119
120
/**
 * Per-channel processing state; one instance per channel is allocated in
 * config_input().
 */
typedef struct DenoiseState {
121
    float analysis_mem[FRAME_SIZE];     /* previous input frame, first half of the analysis window */
122
    float cepstral_mem[CEPS_MEM][NB_BANDS];
123
    int memid;
124
    DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE]; /* overlap-add tail kept by frame_synthesis() */
125
    float pitch_buf[PITCH_BUF_SIZE];    /* sliding input history used for pitch analysis */
126
    float pitch_enh_buf[PITCH_BUF_SIZE];
127
    float last_gain;                    /* pitch gain returned by the previous remove_doubling() call */
128
    int last_period;                    /* pitch period found for the previous frame */
129
    float mem_hp_x[2];
130
    float lastg[NB_BANDS];
131
    float history[FRAME_SIZE];          /* signal blended with the processed output in frame_synthesis() */
132
    RNNState rnn;
133
    AVTXContext *tx, *txi;              /* forward / inverse FFT contexts of length WINDOW_SIZE */
134
    av_tx_fn tx_fn, txi_fn;
135
} DenoiseState;
136
137
/**
 * Private context of the filter instance.
 */
typedef struct AudioRNNContext {
138
    const AVClass *class;
139
140
    char *model_name;
141
    float mix;              /* weight of the processed signal in frame_synthesis() */
142
143
    int channels;           /* copied from the input link in config_input() */
144
    DenoiseState *st;       /* array of 'channels' per-channel states */
145
146
    DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
147
    /* Both dimensions are padded to a multiple of 4 so rows can be fed to
     * scalarproduct_float() in dct(). */
    DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
148
149
    RNNModel *model;
150
151
    AVFloatDSPContext *fdsp;
152
} AudioRNNContext;
153
154
/* Activation identifiers as they appear in the model file; the loader maps
 * them to the ACTIVATION_* constants. */
#define F_ACTIVATION_TANH       0
155
#define F_ACTIVATION_SIGMOID    1
156
#define F_ACTIVATION_RELU       2
157
158
/**
 * Release a model together with every layer and weight array it owns.
 *
 * Safe to call with NULL and with a partially built model: layers that were
 * never allocated are skipped.
 */
static void rnnoise_model_free(RNNModel *model)
159
{
160
/* NOTE(review): FREE_MAYBE is not used anywhere in the visible code and
 * calls free() rather than av_free(). */
#define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
161
/* Frees one dense layer's arrays, then the layer itself. The casts drop the
 * const qualifier the layer structs put on their pointers. */
#define FREE_DENSE(name) do { \
162
    if (model->name) { \
163
        av_free((void *) model->name->input_weights); \
164
        av_free((void *) model->name->bias); \
165
        av_free((void *) model->name); \
166
    } \
167
    } while (0)
168
/* Same as FREE_DENSE, plus the recurrent weight array of a GRU layer. */
#define FREE_GRU(name) do { \
169
    if (model->name) { \
170
        av_free((void *) model->name->input_weights); \
171
        av_free((void *) model->name->recurrent_weights); \
172
        av_free((void *) model->name->bias); \
173
        av_free((void *) model->name); \
174
    } \
175
    } while (0)
176
177
    if (!model)
178
        return;
179
    FREE_DENSE(input_dense);
180
    FREE_GRU(vad_gru);
181
    FREE_GRU(noise_gru);
182
    FREE_GRU(denoise_gru);
183
    FREE_DENSE(denoise_output);
184
    FREE_DENSE(vad_output);
185
    av_free(model);
186
}
187
188
/**
 * Parse a text model file into a newly allocated RNNModel.
 *
 * The file must start with "rnnoise-nu model file version 1", followed by
 * whitespace separated integers describing, in order: input_dense, vad_gru,
 * noise_gru, denoise_gru, denoise_output and vad_output.
 *
 * @param f stream positioned at the start of the model; it is not closed here
 * @return the model, to be released with rnnoise_model_free(), or NULL on
 *         any parse error, out-of-range value or allocation failure
 *         (everything allocated so far is freed before returning NULL)
 */
static RNNModel *rnnoise_model_from_file(FILE *f)
189
{
190
    RNNModel *ret;
191
    DenseLayer *input_dense;
192
    GRULayer *vad_gru;
193
    GRULayer *noise_gru;
194
    GRULayer *denoise_gru;
195
    DenseLayer *denoise_output;
196
    DenseLayer *vad_output;
197
    int in;
198
199
    /* Only format version 1 is accepted. */
    if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
200
        return NULL;
201
202
    ret = av_calloc(1, sizeof(RNNModel));
203
    if (!ret)
204
        return NULL;
205
206
/* Allocates a zeroed layer, keeps it in the local variable 'name' and
 * attaches it to ret->name so rnnoise_model_free() can release it. */
#define ALLOC_LAYER(type, name) \
207
    name = av_calloc(1, sizeof(type)); \
208
    if (!name) { \
209
        rnnoise_model_free(ret); \
210
        return NULL; \
211
    } \
212
    ret->name = name
213
214
    ALLOC_LAYER(DenseLayer, input_dense);
215
    ALLOC_LAYER(GRULayer, vad_gru);
216
    ALLOC_LAYER(GRULayer, noise_gru);
217
    ALLOC_LAYER(GRULayer, denoise_gru);
218
    ALLOC_LAYER(DenseLayer, denoise_output);
219
    ALLOC_LAYER(DenseLayer, vad_output);
220
221
/* Reads one integer that must lie in [0, 128]; anything else aborts the load. */
#define INPUT_VAL(name) do { \
222
    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
223
        rnnoise_model_free(ret); \
224
        return NULL; \
225
    } \
226
    name = in; \
227
    } while (0)
228
229
/* Reads an activation identifier; values other than sigmoid and ReLU fall
 * back to tanh. */
#define INPUT_ACTIVATION(name) do { \
230
    int activation; \
231
    INPUT_VAL(activation); \
232
    switch (activation) { \
233
    case F_ACTIVATION_SIGMOID: \
234
        name = ACTIVATION_SIGMOID; \
235
        break; \
236
    case F_ACTIVATION_RELU: \
237
        name = ACTIVATION_RELU; \
238
        break; \
239
    default: \
240
        name = ACTIVATION_TANH; \
241
    } \
242
    } while (0)
243
244
/* Reads 'len' integers into a new float array. The array is attached to the
 * layer before parsing so that an early return still frees it. */
#define INPUT_ARRAY(name, len) do { \
245
    float *values = av_calloc((len), sizeof(float)); \
246
    if (!values) { \
247
        rnnoise_model_free(ret); \
248
        return NULL; \
249
    } \
250
    name = values; \
251
    for (int i = 0; i < (len); i++) { \
252
        if (fscanf(f, "%d", &in) != 1) { \
253
            rnnoise_model_free(ret); \
254
            return NULL; \
255
        } \
256
        values[i] = in; \
257
    } \
258
    } while (0)
259
260
/* Reads len0 * len2 * len1 integers (file order: k outermost, then i, then j)
 * and stores them transposed as values[j][i][k], with the k dimension padded
 * to a multiple of 4 floats. Padding entries stay zero from av_calloc(). */
#define INPUT_ARRAY3(name, len0, len1, len2) do { \
261
    float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
262
    if (!values) { \
263
        rnnoise_model_free(ret); \
264
        return NULL; \
265
    } \
266
    name = values; \
267
    for (int k = 0; k < (len0); k++) { \
268
        for (int i = 0; i < (len2); i++) { \
269
            for (int j = 0; j < (len1); j++) { \
270
                if (fscanf(f, "%d", &in) != 1) { \
271
                    rnnoise_model_free(ret); \
272
                    return NULL; \
273
                } \
274
                values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
275
            } \
276
        } \
277
    } \
278
    } while (0)
279
280
/* Dense layer record: nb_inputs, nb_neurons, activation, weights, bias. */
#define INPUT_DENSE(name) do { \
281
    INPUT_VAL(name->nb_inputs); \
282
    INPUT_VAL(name->nb_neurons); \
283
    ret->name ## _size = name->nb_neurons; \
284
    INPUT_ACTIVATION(name->activation); \
285
    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
286
    INPUT_ARRAY(name->bias, name->nb_neurons); \
287
    } while (0)
288
289
/* GRU layer record: nb_inputs, nb_neurons, activation, input weights,
 * recurrent weights, then 3 * nb_neurons bias values. */
#define INPUT_GRU(name) do { \
290
    INPUT_VAL(name->nb_inputs); \
291
    INPUT_VAL(name->nb_neurons); \
292
    ret->name ## _size = name->nb_neurons; \
293
    INPUT_ACTIVATION(name->activation); \
294
    INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
295
    INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
296
    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
297
    } while (0)
298
299
    INPUT_DENSE(input_dense);
300
    INPUT_GRU(vad_gru);
301
    INPUT_GRU(noise_gru);
302
    INPUT_GRU(denoise_gru);
303
    INPUT_DENSE(denoise_output);
304
    INPUT_DENSE(vad_output);
305
306
    /* The VAD output layer must produce exactly one value. */
    if (vad_output->nb_neurons != 1) {
307
        rnnoise_model_free(ret);
308
        return NULL;
309
    }
310
311
    return ret;
312
}
313
314
/**
 * Advertise the formats this filter accepts: planar float samples, any
 * channel count, and a sample rate of 48000 Hz only.
 *
 * @return a negative AVERROR code on failure, otherwise the result of
 *         ff_set_common_samplerates()
 */
static int query_formats(AVFilterContext *ctx)
{
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP,
        AV_SAMPLE_FMT_NONE
    };
    int sample_rates[] = { 48000, -1 };
    AVFilterChannelLayouts *layouts = NULL;
    AVFilterFormats *formats = NULL;
    int ret;

    /* Sample format */
    if (!(formats = ff_make_format_list(sample_fmts)))
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_formats(ctx, formats)) < 0)
        return ret;

    /* Channel layout / count */
    if (!(layouts = ff_all_channel_counts()))
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_channel_layouts(ctx, layouts)) < 0)
        return ret;

    /* Sample rate */
    if (!(formats = ff_make_format_list(sample_rates)))
        return AVERROR(ENOMEM);

    return ff_set_common_samplerates(ctx, formats);
}
344
345
/**
 * Allocate the per-channel denoiser state once the input link is configured.
 *
 * For every channel this creates zeroed GRU state vectors (lengths rounded
 * up to a multiple of 16 floats) and a forward and an inverse FFT of length
 * WINDOW_SIZE.
 *
 * @return 0 on success, AVERROR(ENOMEM) or the av_tx_init() error otherwise;
 *         partially allocated state is left in the context on failure
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;
    int ret;

    s->channels = inlink->channels;
    s->st = av_calloc(s->channels, sizeof(*s->st));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int ch = 0; ch < s->channels; ch++) {
        DenoiseState *st = s->st + ch;
        RNNState *rnn = &st->rnn;

        rnn->model             = s->model;
        rnn->vad_gru_state     = av_calloc(FFALIGN(s->model->vad_gru_size, 16),     sizeof(float));
        rnn->noise_gru_state   = av_calloc(FFALIGN(s->model->noise_gru_size, 16),   sizeof(float));
        rnn->denoise_gru_state = av_calloc(FFALIGN(s->model->denoise_gru_size, 16), sizeof(float));
        if (!(rnn->vad_gru_state && rnn->noise_gru_state && rnn->denoise_gru_state))
            return AVERROR(ENOMEM);

        /* Forward transform */
        if ((ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0)) < 0)
            return ret;

        /* Inverse transform */
        if ((ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0)) < 0)
            return ret;
    }

    return 0;
}
380
381
/**
 * Run a second-order recursive filter over N samples.
 *
 * @param y   output samples
 * @param mem two state values, read on entry and updated in place
 * @param x   input samples
 * @param b   two feed-forward coefficients
 * @param a   two feedback coefficients
 * @param N   number of samples to process
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    int n = 0;

    while (n < N) {
        const float in  = x[n];
        const float res = in + mem[0];

        /* Advance the two-element state before storing the output. */
        mem[0] = mem[1] + (b[0]*in - a[0]*res);
        mem[1] = (b[1]*in - a[1]*res);
        y[n++] = res;
    }
}
394
395
/*
 * Typed wrappers around memmove/memset/memcpy: 'n' counts elements of
 * *(dst), not bytes. The "0*((dst)-(src))" term adds nothing to the size;
 * it only makes compilation fail when dst and src are pointers to
 * incompatible types.
 */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
396
#define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
397
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
398
399
/**
 * Forward FFT of one real-valued window.
 *
 * @param st  channel state providing the forward transform (tx / tx_fn)
 * @param out receives the first FREQ_SIZE complex bins
 * @param in  WINDOW_SIZE real input samples
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    /* The transform was created without AV_TX_UNALIGNED, so its buffers
     * must be aligned; plain stack arrays give no such guarantee. */
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    for (int i = 0; i < WINDOW_SIZE; i++) {
        x[i].re = in[i];
        x[i].im = 0;
    }

    /* Stride is the size of one (complex) sample in bytes. */
    st->tx_fn(st->tx, y, x, sizeof(AVComplexFloat));

    RNN_COPY(out, y, FREQ_SIZE);
}
413
414
/**
 * Inverse FFT producing one real-valued window.
 *
 * The upper half of the spectrum is rebuilt as the complex conjugate mirror
 * of the supplied bins, and the result is scaled by 1 / WINDOW_SIZE.
 *
 * @param st  channel state providing the inverse transform (txi / txi_fn)
 * @param out receives WINDOW_SIZE real samples
 * @param in  FREQ_SIZE complex bins
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    /* The transform was created without AV_TX_UNALIGNED, so its buffers
     * must be aligned; plain stack arrays give no such guarantee. */
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    RNN_COPY(x, in, FREQ_SIZE);

    for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
        x[i].re =  x[WINDOW_SIZE - i].re;
        x[i].im = -x[WINDOW_SIZE - i].im;
    }

    /* Stride is the size of one (complex) sample in bytes. */
    st->txi_fn(st->txi, y, x, sizeof(AVComplexFloat));

    for (int i = 0; i < WINDOW_SIZE; i++)
        out[i] = y[i].re / WINDOW_SIZE;
}
431
432
/*
 * Band edge table with NB_BANDS entries. Users shift each entry left by
 * FRAME_SIZE_SHIFT to obtain the index of the first spectrum bin of a band.
 * The row of labels inside the initializer gives the nominal frequency of
 * each edge.
 */
static const uint8_t eband5ms[] = {
433
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
434
  0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
435
};
436
437
static void compute_band_energy(float *bandE, const AVComplexFloat *X)
438
{
439
    float sum[NB_BANDS] = {0};
440
441
    for (int i = 0; i < NB_BANDS - 1; i++) {
442
        int band_size;
443
444
        band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
445
        for (int j = 0; j < band_size; j++) {
446
            float tmp, frac = (float)j / band_size;
447
448
            tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
449
            tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
450
            sum[i]     += (1.f - frac) * tmp;
451
            sum[i + 1] +=        frac  * tmp;
452
        }
453
    }
454
455
    sum[0] *= 2;
456
    sum[NB_BANDS - 1] *= 2;
457
458
    for (int i = 0; i < NB_BANDS; i++)
459
        bandE[i] = sum[i];
460
}
461
462
static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
463
{
464
    float sum[NB_BANDS] = { 0 };
465
466
    for (int i = 0; i < NB_BANDS - 1; i++) {
467
        int band_size;
468
469
        band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
470
        for (int j = 0; j < band_size; j++) {
471
            float tmp, frac = (float)j / band_size;
472
473
            tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
474
            tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
475
            sum[i]     += (1 - frac) * tmp;
476
            sum[i + 1] +=      frac  * tmp;
477
        }
478
    }
479
480
    sum[0] *= 2;
481
    sum[NB_BANDS-1] *= 2;
482
483
    for (int i = 0; i < NB_BANDS; i++)
484
        bandE[i] = sum[i];
485
}
486
487
/**
 * Window the previous and the current input frame, transform them and
 * compute the band energies.
 *
 * @param s  filter context (window table and DSP functions)
 * @param st channel state; analysis_mem is replaced by the current frame
 * @param X  receives the spectrum
 * @param Ex receives NB_BANDS band energies
 * @param in FRAME_SIZE new input samples
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, windowed, [WINDOW_SIZE]);
    float *const older = windowed;
    float *const newer = windowed + FRAME_SIZE;

    /* First half: previous frame, second half: current frame. */
    RNN_COPY(older, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(newer, in, FRAME_SIZE);
    /* Remember the current frame for the next call. */
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);

    s->fdsp->vector_fmul(windowed, windowed, s->window, WINDOW_SIZE);
    forward_transform(st, X, windowed);
    compute_band_energy(Ex, X);
}
498
499
/**
 * Turn a spectrum back into FRAME_SIZE output samples.
 *
 * The inverse transform is windowed and overlap-added with the tail kept
 * from the previous call, then blended with st->history: the processed
 * signal is weighted by s->mix and the history by 1 - max(s->mix, 0).
 *
 * @param s   filter context
 * @param st  channel state; synthesis_mem is updated with the new tail
 * @param out receives FRAME_SIZE samples
 * @param y   spectrum to synthesize
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
    const float wet = s->mix;
    const float dry = 1.f - FFMAX(wet, 0.f);
    const float *unprocessed = st->history;

    inverse_transform(st, x, y);
    s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);

    /* Overlap-add with the second half of the previous window. */
    s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
    RNN_COPY(out, x, FRAME_SIZE);
    RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);

    for (int i = 0; i < FRAME_SIZE; i++)
        out[i] = out[i] * wet + unprocessed[i] * dry;
}
515
516
/**
 * Accumulate four consecutive cross-correlation lags at once:
 * sum[k] += x[j] * y[j + k] for j in [0, len) and k in [0, 3].
 *
 * The first three y values are always read, and y is read up to index
 * len + 2. sum[] is accumulated into, not cleared.
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    /* Rotating window over four consecutive y samples. */
    float ya = y[0], yb = y[1], yc = y[2], yd = 0;
    int pos = 0;
    int left;

    /* Main loop: four x samples per pass. */
    for (; pos < len - 3; pos += 4) {
        float xv;

        xv = x[pos];
        yd = y[pos + 3];
        sum[0] += xv * ya;
        sum[1] += xv * yb;
        sum[2] += xv * yc;
        sum[3] += xv * yd;

        xv = x[pos + 1];
        ya = y[pos + 4];
        sum[0] += xv * yb;
        sum[1] += xv * yc;
        sum[2] += xv * yd;
        sum[3] += xv * ya;

        xv = x[pos + 2];
        yb = y[pos + 5];
        sum[0] += xv * yc;
        sum[1] += xv * yd;
        sum[2] += xv * ya;
        sum[3] += xv * yb;

        xv = x[pos + 3];
        yc = y[pos + 6];
        sum[0] += xv * yd;
        sum[1] += xv * ya;
        sum[2] += xv * yb;
        sum[3] += xv * yc;
    }

    /* Up to three leftover samples, continuing the same rotation. */
    left = len - pos;

    if (left > 0) {
        const float xv = x[pos];

        yd = y[pos + 3];
        sum[0] += xv * ya;
        sum[1] += xv * yb;
        sum[2] += xv * yc;
        sum[3] += xv * yd;
    }

    if (left > 1) {
        const float xv = x[pos + 1];

        ya = y[pos + 4];
        sum[0] += xv * yb;
        sum[1] += xv * yc;
        sum[2] += xv * yd;
        sum[3] += xv * ya;
    }

    if (left > 2) {
        const float xv = x[pos + 2];

        yb = y[pos + 5];
        sum[0] += xv * yc;
        sum[1] += xv * yd;
        sum[2] += xv * ya;
        sum[3] += xv * yb;
    }
}
584
585
/**
 * Dot product of two float vectors.
 *
 * @return sum of x[i] * y[i] for i in [0, N); 0 when N <= 0
 */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;
    int i = 0;

    while (i < N) {
        acc += x[i] * y[i];
        i++;
    }

    return acc;
}
595
596
/**
 * Cross-correlate x with y at lags 0 .. max_pitch - 1.
 *
 * @param x         reference vector of len samples
 * @param y         searched vector
 * @param xcorr     receives max_pitch correlation values
 * @param len       number of samples per correlation
 * @param max_pitch number of lags to evaluate
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    /* Groups of four lags go through the unrolled kernel. */
    for (; lag + 3 < max_pitch; lag += 4) {
        float acc[4] = { 0, 0, 0, 0 };

        xcorr_kernel(x, y + lag, acc, len);

        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = acc[k];
    }

    /* Remaining lags when max_pitch is not a multiple of 4. */
    while (lag < max_pitch) {
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
        lag++;
    }
}
616
617
static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
618
                         float       *ac,  /* out: [0...lag-1] ac values */
619
                         const float *window,
620
                         int          overlap,
621
                         int          lag,
622
                         int          n)
623
{
624
    int fastN = n - lag;
625
    int shift;
626
    const float *xptr;
627
    float xx[PITCH_BUF_SIZE>>1];
628
629
    if (overlap == 0) {
630
        xptr = x;
631
    } else {
632
        for (int i = 0; i < n; i++)
633
            xx[i] = x[i];
634
        for (int i = 0; i < overlap; i++) {
635
            xx[i] = x[i] * window[i];
636
            xx[n-i-1] = x[n-i-1] * window[i];
637
        }
638
        xptr = xx;
639
    }
640
641
    shift = 0;
642
    celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
643
644
    for (int k = 0; k <= lag; k++) {
645
        float d = 0.f;
646
647
        for (int i = k + fastN; i < n; i++)
648
            d += xptr[i] * xptr[i-k];
649
        ac[k] += d;
650
    }
651
652
    return shift;
653
}
654
655
/**
 * Derive p prediction coefficients from autocorrelation values using a
 * Levinson-Durbin style recursion.
 *
 * The output is all zeros when ac[0] is 0. The recursion stops early once
 * the residual error drops below 0.001 * ac[0]; coefficients not reached
 * stay 0.
 */
static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
                const float *ac,   /* in:  [0...p] autocorrelation values  */
                          int p)
{
    float err = ac[0];

    memset(lpc, 0, p * sizeof(*lpc));
    if (ac[0] == 0)
        return;

    for (int order = 0; order < p; order++) {
        float acc = 0;
        float refl;

        /* Reflection coefficient of this iteration */
        for (int j = 0; j < order; j++)
            acc += (lpc[j] * ac[order - j]);
        acc += ac[order + 1];
        refl = -acc/err;

        /* Fold it into the coefficients found so far */
        lpc[order] = refl;
        for (int j = 0; j < (order + 1) >> 1; j++) {
            const float lo = lpc[j];
            const float hi = lpc[order-1-j];

            lpc[j]         = lo + (refl*hi);
            lpc[order-1-j] = hi + (refl*lo);
        }

        err = err - (refl * refl *err);
        /* Stop once the error is below 0.001 of the input energy */
        if (err < .001f * ac[0])
            break;
    }
}
687
688
/**
 * Five-tap filter: y[i] = x[i] + sum over k of num[k] * (k+1)-th previous input.
 *
 * Works in place (y may equal x) because each input sample is read before
 * the matching output is stored.
 *
 * @param x   N input samples
 * @param num five tap coefficients
 * @param y   N output samples
 * @param N   number of samples
 * @param mem the five previous inputs, most recent first; updated on return
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float taps[5], hist[5];

    for (int k = 0; k < 5; k++) {
        taps[k] = num[k];
        hist[k] = mem[k];
    }

    for (int i = 0; i < N; i++) {
        const float cur = x[i];
        float acc = cur;

        for (int k = 0; k < 5; k++)
            acc += (taps[k]*hist[k]);

        /* Shift the delay line and insert the newest input. */
        for (int k = 4; k > 0; k--)
            hist[k] = hist[k - 1];
        hist[0] = cur;

        y[i] = acc;
    }

    for (int k = 0; k < 5; k++)
        mem[k] = hist[k];
}
730
731
static void pitch_downsample(float *x[], float *x_lp,
732
                             int len, int C)
733
{
734
    float ac[5];
735
    float tmp=Q15ONE;
736
    float lpc[4], mem[5]={0,0,0,0,0};
737
    float lpc2[5];
738
    float c1 = .8f;
739
740
    for (int i = 1; i < len >> 1; i++)
741
        x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
742
    x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
743
    if (C==2) {
744
        for (int i = 1; i < len >> 1; i++)
745
            x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
746
        x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
747
    }
748
749
    celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
750
751
    /* Noise floor -40 dB */
752
    ac[0] *= 1.0001f;
753
    /* Lag windowing */
754
    for (int i = 1; i <= 4; i++) {
755
        /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
756
        ac[i] -= ac[i]*(.008f*i)*(.008f*i);
757
    }
758
759
    celt_lpc(lpc, ac, 4);
760
    for (int i = 0; i < 4; i++) {
761
        tmp = .9f * tmp;
762
        lpc[i] = (lpc[i] * tmp);
763
    }
764
    /* Add a zero */
765
    lpc2[0] = lpc[0] + .8f;
766
    lpc2[1] = lpc[1] + (c1 * lpc[0]);
767
    lpc2[2] = lpc[2] + (c1 * lpc[1]);
768
    lpc2[3] = lpc[3] + (c1 * lpc[2]);
769
    lpc2[4] = (c1 * lpc[3]);
770
    celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
771
}
772
773
/**
 * Compute two dot products that share the same left operand.
 *
 * @param x   common vector
 * @param y01 first right-hand vector
 * @param y02 second right-hand vector
 * @param N   number of elements
 * @param xy1 receives dot(x, y01)
 * @param xy2 receives dot(x, y02)
 */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;
    int i = 0;

    while (i < N) {
        acc1 += (x[i] * y01[i]);
        acc2 += (x[i] * y02[i]);
        i++;
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
786
787
static float compute_pitch_gain(float xy, float xx, float yy)
788
{
789
    return xy / sqrtf(1.f + xx * yy);
790
}
791
792
/* For divisor k > 2, the secondary candidate period tested together with
 * T0/k is roughly second_check[k] * T0 / k. */
static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
793
/**
 * Check whether a shorter period T0/k (k = 2..15) explains the signal about
 * as well as the initial estimate, and return the gain of the chosen period.
 *
 * All period and length arguments are given at twice the resolution of the
 * buffer x and are halved on entry.
 *
 * @param x           signal; reads reach back maxperiod/2 samples before
 *                    the analysed segment, which starts at x + maxperiod/2
 * @param maxperiod   largest allowed period
 * @param minperiod   smallest allowed period; also the lower clamp of *T0_
 * @param N           length of the analysed segment
 * @param T0_         in: initial period estimate, out: refined period
 * @param prev_period period chosen for the previous frame
 * @param prev_gain   gain returned for the previous frame
 * @return pitch gain for the selected period
 */
static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
794
                             int *T0_, int prev_period, float prev_gain)
795
{
796
    int k, i, T, T0;
797
    float g, g0;
798
    float pg;
799
    float xy,xx,yy,xy2;
800
    float xcorr[3];
801
    float best_xy, best_yy;
802
    int offset;
803
    int minperiod0;
804
    float yy_lookup[PITCH_MAX_PERIOD+1];
805
806
    minperiod0 = minperiod;
807
    maxperiod /= 2;
808
    minperiod /= 2;
809
    *T0_ /= 2;
810
    prev_period /= 2;
811
    N /= 2;
812
    x += maxperiod;
813
    if (*T0_>=maxperiod)
814
        *T0_=maxperiod-1;
815
816
    T = T0 = *T0_;
817
    dual_inner_prod(x, x, x-T0, N, &xx, &xy);
818
    /* yy_lookup[i] is the energy of the N-sample segment starting i samples
     * before x, updated incrementally and clamped at 0. */
    yy_lookup[0] = xx;
819
    yy=xx;
820
    for (i = 1; i <= maxperiod; i++) {
821
        yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
822
        yy_lookup[i] = FFMAX(0, yy);
823
    }
824
    yy = yy_lookup[T0];
825
    best_xy = xy;
826
    best_yy = yy;
827
    g = g0 = compute_pitch_gain(xy, xx, yy);
828
    /* Look for any pitch at T/k */
829
    for (k = 2; k <= 15; k++) {
830
        int T1, T1b;
831
        float g1;
832
        float cont=0;
833
        float thresh;
834
        /* T0 / k rounded to nearest */
        T1 = (2*T0+k)/(2*k);
835
        if (T1 < minperiod)
836
            break;
837
        /* Look for another strong correlation at T1b */
838
        if (k==2)
839
        {
840
            if (T1+T0>maxperiod)
841
                T1b = T0;
842
            else
843
                T1b = T0+T1;
844
        } else
845
        {
846
            T1b = (2*second_check[k]*T0+k)/(2*k);
847
        }
848
        dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
849
        xy = .5f * (xy + xy2);
850
        yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
851
        g1 = compute_pitch_gain(xy, xx, yy);
852
        /* Candidates close to the previous frame's period get a lower
         * acceptance threshold. */
        if (FFABS(T1-prev_period)<=1)
853
            cont = prev_gain;
854
        else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
855
            cont = prev_gain * .5f;
856
        else
857
            cont = 0;
858
        thresh = FFMAX(.3f, (.7f * g0) - cont);
859
        /* Bias against very high pitch (very short period) to avoid false-positives
860
           due to short-term correlation */
861
        /* NOTE(review): the second branch below can never run, because
         * T1 < 2*minperiod implies T1 < 3*minperiod. Left untouched since
         * reordering would change the filter's output. */
        if (T1<3*minperiod)
862
            thresh = FFMAX(.4f, (.85f * g0) - cont);
863
        else if (T1<2*minperiod)
864
            thresh = FFMAX(.5f, (.9f * g0) - cont);
865
        if (g1 > thresh)
866
        {
867
            best_xy = xy;
868
            best_yy = yy;
869
            T = T1;
870
            g = g1;
871
        }
872
    }
873
    best_xy = FFMAX(0, best_xy);
874
    if (best_yy <= best_xy)
875
        pg = Q15ONE;
876
    else
877
        pg = best_xy/(best_yy + 1);
878
879
    /* Compare the correlation at T-1, T and T+1 to pick a +-1 offset at the
     * doubled resolution. */
    for (k = 0; k < 3; k++)
880
        xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
881
    if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
882
        offset = 1;
883
    else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
884
        offset = -1;
885
    else
886
        offset = 0;
887
    if (pg > g)
888
        pg = g;
889
    *T0_ = 2*T+offset;
890
891
    if (*T0_<minperiod0)
892
        *T0_=minperiod0;
893
    return pg;
894
}
895
896
/**
 * Find the two lags with the highest normalized correlation.
 *
 * Candidates are ranked by xcorr[i]^2 / Syy, where Syy is the running
 * energy of the len-sample segment of y starting at lag i (never below 1).
 * Only lags with a positive correlation are considered.
 *
 * @param xcorr      max_pitch correlation values
 * @param y          searched signal, len + max_pitch samples
 * @param len        segment length
 * @param max_pitch  number of lags
 * @param best_pitch receives the best lag in [0] and the runner-up in [1];
 *                   they stay 0 and 1 when nothing qualifies
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float best_num[2] = { -1, -1 };
    float best_den[2] = {  0,  0 };
    float Syy = 1.f;

    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        Syy += y[j] * y[j];

    for (int lag = 0; lag < max_pitch; lag++) {
        const float corr = xcorr[lag];

        if (corr > 0) {
            /* Scaling keeps the square away from overflow (inf) and
               underflow for the expected range of correlation values */
            const float scaled = corr * 1e-12f;
            const float num = scaled * scaled;

            /* Cross-multiplied comparison of num/Syy against the stored ratios */
            if ((num * best_den[1]) > (best_num[1] * Syy)) {
                int slot = 1;

                if ((num * best_den[0]) > (best_num[0] * Syy)) {
                    /* New leader: demote the previous one to runner-up */
                    best_num[1]   = best_num[0];
                    best_den[1]   = best_den[0];
                    best_pitch[1] = best_pitch[0];
                    slot = 0;
                }
                best_num[slot]   = num;
                best_den[slot]   = Syy;
                best_pitch[slot] = lag;
            }
        }

        /* Slide the energy window by one sample */
        Syy += y[lag+len]*y[lag+len] - y[lag] * y[lag];
        if (1 > Syy)
            Syy = 1;
    }
}
942
943
static void pitch_search(const float *x_lp, float *y,
944
                         int len, int max_pitch, int *pitch)
945
{
946
    int lag;
947
    int best_pitch[2]={0,0};
948
    int offset;
949
950
    float x_lp4[WINDOW_SIZE];
951
    float y_lp4[WINDOW_SIZE];
952
    float xcorr[WINDOW_SIZE];
953
954
    lag = len+max_pitch;
955
956
    /* Downsample by 2 again */
957
    for (int j = 0; j < len >> 2; j++)
958
        x_lp4[j] = x_lp[2*j];
959
    for (int j = 0; j < lag >> 2; j++)
960
        y_lp4[j] = y[2*j];
961
962
    /* Coarse search with 4x decimation */
963
964
    celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
965
966
    find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
967
968
    /* Finer search with 2x decimation */
969
    for (int i = 0; i < max_pitch >> 1; i++) {
970
        float sum;
971
        xcorr[i] = 0;
972
        if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
973
            continue;
974
        sum = celt_inner_prod(x_lp, y+i, len>>1);
975
        xcorr[i] = FFMAX(-1, sum);
976
    }
977
978
    find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
979
980
    /* Refine by pseudo-interpolation */
981
    if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
982
        float a, b, c;
983
984
        a = xcorr[best_pitch[0] - 1];
985
        b = xcorr[best_pitch[0]];
986
        c = xcorr[best_pitch[0] + 1];
987
        if (c - a > .7f * (b - a))
988
            offset = 1;
989
        else if (a - c > .7f * (b-c))
990
            offset = -1;
991
        else
992
            offset = 0;
993
    } else {
994
        offset = 0;
995
    }
996
997
    *pitch = 2 * best_pitch[0] - offset;
998
}
999
1000
/**
 * Transform NB_BANDS band values with the precomputed table s->dct_table.
 *
 * @param s   filter context providing dct_table and the DSP functions
 * @param out receives NB_BANDS coefficients
 * @param in  NB_BANDS input values; no padding or alignment is required
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    /*
     * scalarproduct_float() is run over FFALIGN(NB_BANDS, 4) elements, but
     * callers only provide NB_BANDS floats. Copy the input into an aligned,
     * zero-padded buffer so the product never reads past the caller's array
     * and stray values (possibly NaN/Inf) cannot leak into the sum.
     * The padding columns of dct_table are assumed to be zero - the table
     * setup is not visible here.
     */
    LOCAL_ALIGNED_32(float, padded, [FFALIGN(NB_BANDS, 4)]);

    RNN_COPY(padded, in, NB_BANDS);
    for (int i = NB_BANDS; i < FFALIGN(NB_BANDS, 4); i++)
        padded[i] = 0.f;

    for (int i = 0; i < NB_BANDS; i++) {
        float sum;

        sum = s->fdsp->scalarproduct_float(padded, s->dct_table[i], FFALIGN(NB_BANDS, 4));
        out[i] = sum * sqrtf(2.f / NB_BANDS);
    }
}
1009
1010
/**
 * Analyse one input frame and fill the feature vector.
 *
 * Steps, in order:
 *  - frame_analysis() produces the spectrum X and band energies Ex;
 *  - st->pitch_buf is shifted by FRAME_SIZE and the new frame appended;
 *  - a pitch lag is estimated (pitch_downsample, pitch_search,
 *    remove_doubling) and stored with its gain in st->last_period /
 *    st->last_gain;
 *  - the signal delayed by that lag is windowed and transformed into P,
 *    giving Ep and the normalised band correlation Exp;
 *  - DCT coefficients of Exp, the pitch lag, DCT coefficients of the log
 *    band energies, their combinations with the two previous frames held in
 *    st->cepstral_mem, and a spectral variability value are written to
 *    features.
 *
 * The pitch state above is updated even when the frame is reported as
 * silent; only the cepstral history is left untouched in that case.
 *
 * @return 1 if the summed band energy is below 0.04, in which case features
 *         is cleared; 0 otherwise
 */
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
{
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
    float pitch_buf[PITCH_BUF_SIZE>>1];
    float corr_ceps[NB_BANDS];
    float *pre[1];
    /* Start of the pitch-correlation section of the feature vector. */
    float *const pitch_feat = features + NB_BANDS + 2*NB_DELTA_CEPS;
    float *ceps_0, *ceps_1, *ceps_2;
    float E = 0;
    float spec_variability = 0;
    float logMax = -2;
    float follow = -2;
    float gain;
    int pitch_index;

    frame_analysis(s, st, X, Ex, in);

    /* Drop the oldest frame from the pitch history and append the new one. */
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);

    pre[0] = &st->pitch_buf[0];
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
    pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
                 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
    pitch_index = PITCH_MAX_PERIOD-pitch_index;

    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
                           PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
    st->last_period = pitch_index;
    st->last_gain   = gain;

    /* Window of history delayed by the estimated lag. */
    for (int n = 0; n < WINDOW_SIZE; n++)
        p[n] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+n];

    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
    forward_transform(st, P, p);
    compute_band_energy(Ep, P);
    compute_band_corr(Exp, X, P);

    for (int b = 0; b < NB_BANDS; b++)
        Exp[b] /= sqrtf(.001f+Ex[b]*Ep[b]);

    dct(s, corr_ceps, Exp);

    for (int c = 0; c < NB_DELTA_CEPS; c++)
        pitch_feat[c] = corr_ceps[c];

    pitch_feat[0] -= 1.3;
    pitch_feat[1] -= 0.9;
    pitch_feat[NB_DELTA_CEPS] = .01*(pitch_index-300);

    /* Log band energies, floored relative to the running maximum and to a
     * value that decays by 1.5 per band. */
    for (int b = 0; b < NB_BANDS; b++) {
        Ly[b] = log10f(1e-2f + Ex[b]);
        Ly[b] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[b]));
        logMax = FFMAX(logMax, Ly[b]);
        follow = FFMAX(follow-1.5, Ly[b]);
        E += Ex[b];
    }

    if (E < 0.04f) {
        /* If there's no audio, avoid messing up the state. */
        RNN_CLEAR(features, NB_FEATURES);
        return 1;
    }

    dct(s, features, Ly);
    features[0] -= 12;
    features[1] -= 4;

    /* Current slot plus the two preceding slots of the circular history. */
    ceps_0 = st->cepstral_mem[st->memid];
    ceps_1 = st->cepstral_mem[(st->memid < 1) ? CEPS_MEM+st->memid-1 : st->memid-1];
    ceps_2 = st->cepstral_mem[(st->memid < 2) ? CEPS_MEM+st->memid-2 : st->memid-2];

    for (int b = 0; b < NB_BANDS; b++)
        ceps_0[b] = features[b];

    for (int c = 0; c < NB_DELTA_CEPS; c++) {
        features[c] = ceps_0[c] + ceps_1[c] + ceps_2[c];
        features[NB_BANDS+c] = ceps_0[c] - ceps_2[c];
        features[NB_BANDS+NB_DELTA_CEPS+c] =  ceps_0[c] - 2*ceps_1[c] + ceps_2[c];
    }

    /* Advance the circular write position. */
    if (++st->memid == CEPS_MEM)
        st->memid = 0;

    /* Spectral variability: for every stored frame take the squared distance
     * to its nearest other stored frame, then sum over all frames. */
    for (int a = 0; a < CEPS_MEM; a++) {
        float mindist = 1e15f;

        for (int b = 0; b < CEPS_MEM; b++) {
            float dist = 0.f;

            if (b == a)
                continue;

            for (int k = 0; k < NB_BANDS; k++) {
                float diff;

                diff = st->cepstral_mem[a][k] - st->cepstral_mem[b][k];
                dist += diff*diff;
            }

            mindist = FFMIN(mindist, dist);
        }

        spec_variability += mindist;
    }

    features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;

    return 0;
}
1117
1118
static void interp_band_gain(float *g, const float *bandE)
1119
{
1120
    memset(g, 0, sizeof(*g) * FREQ_SIZE);
1121
1122
    for (int i = 0; i < NB_BANDS - 1; i++) {
1123
        const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1124
1125
        for (int j = 0; j < band_size; j++) {
1126
            float frac = (float)j / band_size;
1127
1128
            g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1129
        }
1130
    }
1131
}
1132
1133
/**
 * Add a band-weighted amount of the pitch spectrum P to X, then rescale X
 * so that its band energies match Ex again.
 *
 * For each band a weight in [0, 1] is derived from the pitch correlation
 * Exp and the gain g (1 when Exp exceeds g), square-rooted, and multiplied
 * by sqrt(Ex / Ep). Weights and the later normalisation factors are both
 * spread over the frequency bins with interp_band_gain().
 *
 * @param X   spectrum, modified in place (FREQ_SIZE bins)
 * @param P   pitch spectrum (FREQ_SIZE bins are read)
 * @param Ex  band energies of X before filtering
 * @param Ep  band energies of P
 * @param Exp band pitch correlation
 * @param g   band gains
 */
static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
                         const float *Exp, const float *g)
{
    float mix_band[NB_BANDS];
    float energy_after[NB_BANDS];
    float scale_band[NB_BANDS];
    float mix_bin[FREQ_SIZE]   = {0};
    float scale_bin[FREQ_SIZE] = {0};

    for (int b = 0; b < NB_BANDS; b++) {
        mix_band[b]  = (Exp[b]>g[b]) ? 1
                     : SQUARE(Exp[b])*(1-SQUARE(g[b]))/(.001 + SQUARE(g[b])*(1-SQUARE(Exp[b])));
        mix_band[b]  = sqrtf(av_clipf(mix_band[b], 0, 1));
        mix_band[b] *= sqrtf(Ex[b]/(1e-8+Ep[b]));
    }

    interp_band_gain(mix_bin, mix_band);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re += mix_bin[k]*P[k].re;
        X[k].im += mix_bin[k]*P[k].im;
    }

    /* Undo the energy change introduced by the addition above. */
    compute_band_energy(energy_after, X);
    for (int b = 0; b < NB_BANDS; b++)
        scale_band[b] = sqrtf(Ex[b] / (1e-8+energy_after[b]));

    interp_band_gain(scale_bin, scale_band);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re *= scale_bin[k];
        X[k].im *= scale_bin[k];
    }
}
1163
1164
/*
 * Lookup table used by tansig_approx(): entry i is the base value for the
 * point 0.04 * i, i = 0..200 (i.e. arguments 0 to 8). The values appear to
 * be tanh() rounded to six decimals. The leading comment on each row is the
 * index of its first entry.
 */
static const float tansig_table[201] = {
    /*   0 */ 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    /*   5 */ 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    /*  10 */ 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    /*  15 */ 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    /*  20 */ 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    /*  25 */ 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    /*  30 */ 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    /*  35 */ 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    /*  40 */ 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    /*  45 */ 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    /*  50 */ 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    /*  55 */ 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    /*  60 */ 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    /*  65 */ 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    /*  70 */ 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    /*  75 */ 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    /*  80 */ 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    /*  85 */ 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    /*  90 */ 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    /*  95 */ 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    /* 100 */ 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    /* 105 */ 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    /* 110 */ 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    /* 115 */ 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    /* 120 */ 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    /* 125 */ 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    /* 130 */ 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    /* 135 */ 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    /* 140 */ 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    /* 145 */ 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    /* 150 */ 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    /* 155 */ 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    /* 160 */ 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    /* 165 */ 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    /* 170 */ 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    /* 175 */ 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    /* 180 */ 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    /* 185 */ 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    /* 190 */ 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    /* 195 */ 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    /* 200 */ 1.000000f,
};
1207
1208
static inline float tansig_approx(float x)
1209
{
1210
    float y, dy;
1211
    float sign=1;
1212
    int i;
1213
1214
    /* Tests are reversed to catch NaNs */
1215
    if (!(x<8))
1216
        return 1;
1217
    if (!(x>-8))
1218
        return -1;
1219
    /* Another check in case of -ffast-math */
1220
1221
    if (isnan(x))
1222
       return 0;
1223
1224
    if (x < 0) {
1225
       x=-x;
1226
       sign=-1;
1227
    }
1228
    i = (int)floor(.5f+25*x);
1229
    x -= .04f*i;
1230
    y = tansig_table[i];
1231
    dy = 1-y*y;
1232
    y = y + x*dy*(1 - y*x);
1233
    return sign*y;
1234
}
1235
1236
/**
 * Sigmoid-shaped approximation built on tansig_approx():
 * 0.5 + 0.5 * tansig_approx(x / 2).
 */
static inline float sigmoid_approx(float x)
{
    const float t = tansig_approx(.5f*x);

    return .5f + .5f*t;
}
1240
1241
static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1242
{
1243
    const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1244
1245
    for (int i = 0; i < N; i++) {
1246
        /* Compute update gate. */
1247
        float sum = layer->bias[i];
1248
1249
        for (int j = 0; j < M; j++)
1250
            sum += layer->input_weights[j * stride + i] * input[j];
1251
1252
        output[i] = WEIGHTS_SCALE * sum;
1253
    }
1254
1255
    if (layer->activation == ACTIVATION_SIGMOID) {
1256
        for (int i = 0; i < N; i++)
1257
            output[i] = sigmoid_approx(output[i]);
1258
    } else if (layer->activation == ACTIVATION_TANH) {
1259
        for (int i = 0; i < N; i++)
1260
            output[i] = tansig_approx(output[i]);
1261
    } else if (layer->activation == ACTIVATION_RELU) {
1262
        for (int i = 0; i < N; i++)
1263
            output[i] = FFMAX(0, output[i]);
1264
    } else {
1265
        av_assert0(0);
1266
    }
1267
}
1268
1269
/**
 * Advance a GRU layer by one step, updating state in place.
 *
 * Each neuron owns one row of input_weights (3 * FFALIGN(nb_inputs, 4)
 * floats) and one row of recurrent_weights (3 * FFALIGN(nb_neurons, 4)
 * floats); within a row the update-gate, reset-gate and output segments
 * follow each other. The bias array holds the three groups back to back,
 * nb_neurons entries each.
 *
 * The gate dot products run over the 4-aligned lengths, so input and state
 * must be readable (and the weights meaningful) up to those lengths.
 *
 * @param s     filter context supplying fdsp
 * @param gru   layer description
 * @param state nb_neurons values, overwritten with the new state
 * @param input layer input
 */
static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
{
    LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
    const int nb_in     = gru->nb_inputs;
    const int nb_units  = gru->nb_neurons;
    const int in_len    = FFALIGN(nb_in, 4);
    const int state_len = FFALIGN(nb_units, 4);
    const int in_row    = 3 * in_len;
    const int state_row = 3 * state_len;

    /* Update gate z and reset gate r; neither depends on the other. */
    for (int n = 0; n < nb_units; n++) {
        float upd = gru->bias[n];
        float rst = gru->bias[nb_units + n];

        upd += s->fdsp->scalarproduct_float(gru->input_weights + n * in_row, input, in_len);
        upd += s->fdsp->scalarproduct_float(gru->recurrent_weights + n * state_row, state, state_len);
        z[n] = sigmoid_approx(WEIGHTS_SCALE * upd);

        rst += s->fdsp->scalarproduct_float(gru->input_weights + in_len + n * in_row, input, in_len);
        rst += s->fdsp->scalarproduct_float(gru->recurrent_weights + state_len + n * state_row, state, state_len);
        r[n] = sigmoid_approx(WEIGHTS_SCALE * rst);
    }

    /* Candidate output, blended with the old state through z. */
    for (int n = 0; n < nb_units; n++) {
        float cand = gru->bias[2 * nb_units + n];

        cand += s->fdsp->scalarproduct_float(gru->input_weights + 2 * in_len + n * in_row, input, in_len);
        for (int k = 0; k < nb_units; k++)
            cand += gru->recurrent_weights[2 * state_len + n * state_row + k] * state[k] * r[k];

        switch (gru->activation) {
        case ACTIVATION_SIGMOID:
            cand = sigmoid_approx(WEIGHTS_SCALE * cand);
            break;
        case ACTIVATION_TANH:
            cand = tansig_approx(WEIGHTS_SCALE * cand);
            break;
        case ACTIVATION_RELU:
            cand = FFMAX(0, WEIGHTS_SCALE * cand);
            break;
        default:
            av_assert0(0);
        }

        /* Written to a scratch array because later neurons still read the old state. */
        h[n] = z[n] * state[n] + (1.f - z[n]) * cand;
    }

    RNN_COPY(state, h, nb_units);
}
1319
1320
#define INPUT_SIZE 42
1321
1322
static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1323
{
1324
    LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1325
    LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1326
    LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1327
1328
    compute_dense(rnn->model->input_dense, dense_out, input);
1329
    compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1330
    compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1331
1332
    memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1333
    memcpy(noise_input + rnn->model->input_dense_size,
1334
           rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1335
    memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1336
           input, INPUT_SIZE * sizeof(float));
1337
1338
    compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1339
1340
    memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1341
    memcpy(denoise_input + rnn->model->vad_gru_size,
1342
           rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1343
    memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1344
           input, INPUT_SIZE * sizeof(float));
1345
1346
    compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1347
    compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1348
}
1349
1350
/**
 * Process one frame of one channel.
 *
 * The input is passed through biquad() with the fixed coefficients below
 * and analysed by compute_frame_features(). Unless the frame was reported
 * silent or the filter is disabled, the network is evaluated, the pitch
 * filter applied, and the spectrum scaled by the interpolated band gains
 * (each gain is kept at or above 0.6 times the previous frame's gain). The
 * spectrum is then synthesised into out, and the unfiltered input frame is
 * stored in st->history.
 *
 * @param s        filter context
 * @param st       per-channel state
 * @param out      receives the synthesised frame
 * @param in       FRAME_SIZE input samples
 * @param disabled non-zero to skip the network and gain stages
 * @return the network's VAD output, or 0 when the network was not run
 */
static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
                             int disabled)
{
    static const float a_hp[2] = {-1.99599, 0.99600};
    static const float b_hp[2] = {-2, 1};
    const float gain_decay = .6f;
    AVComplexFloat X[FREQ_SIZE];
    AVComplexFloat P[WINDOW_SIZE];
    LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
    float filtered[FRAME_SIZE];
    float Ex[NB_BANDS], Ep[NB_BANDS];
    float features[NB_FEATURES];
    float band_gain[NB_BANDS];
    float bin_gain[FREQ_SIZE];
    float vad_prob = 0;
    int silence;

    biquad(filtered, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
    silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, filtered);

    if (!(silence || disabled)) {
        compute_rnn(s, &st->rnn, band_gain, &vad_prob, features);
        pitch_filter(X, P, Ex, Ep, Exp, band_gain);

        /* Do not let a band gain fall faster than gain_decay per frame. */
        for (int b = 0; b < NB_BANDS; b++) {
            band_gain[b] = FFMAX(band_gain[b], gain_decay * st->lastg[b]);
            st->lastg[b] = band_gain[b];
        }

        interp_band_gain(bin_gain, band_gain);

        for (int k = 0; k < FREQ_SIZE; k++) {
            X[k].re *= bin_gain[k];
            X[k].im *= bin_gain[k];
        }
    }

    frame_synthesis(s, st, out, X);
    memcpy(st->history, in, FRAME_SIZE * sizeof(*st->history));

    return vad_prob;
}
1393
1394
typedef struct ThreadData {
1395
    AVFrame *in, *out;
1396
} ThreadData;
1397
1398
/**
 * Slice-threading worker: runs rnnoise_channel() on this job's share of the
 * channels.
 *
 * Channels are split evenly by index: job jobnr handles channels
 * [channels * jobnr / nb_jobs, channels * (jobnr + 1) / nb_jobs).
 *
 * @param ctx     filter context
 * @param arg     ThreadData with the input and output frames
 * @param jobnr   index of this job
 * @param nb_jobs total number of jobs
 * @return always 0
 */
static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
{
    AudioRNNContext *s = ctx->priv;
    ThreadData *td = arg;
    AVFrame *src = td->in;
    AVFrame *dst = td->out;
    const int first = (dst->channels * jobnr) / nb_jobs;
    const int limit = (dst->channels * (jobnr+1)) / nb_jobs;
    int ch = first;

    while (ch < limit) {
        rnnoise_channel(s, &s->st[ch],
                        (float *)dst->extended_data[ch],
                        (const float *)src->extended_data[ch],
                        ctx->is_disabled);
        ch++;
    }

    return 0;
}
1416
1417
/**
 * Denoise one input frame and send the result downstream.
 *
 * A FRAME_SIZE output buffer is allocated, the frame properties of the
 * input (timestamps, metadata, side data, ...) are copied onto it with
 * av_frame_copy_props() rather than just the pts, and the channels are
 * processed through the filter's execute() callback using at most one job
 * per channel. The input frame is always freed.
 *
 * @param inlink input link the frame arrived on
 * @param in     input frame; ownership is taken
 * @return the result of ff_filter_frame(), or a negative AVERROR code if
 *         allocating the output or copying the frame properties fails
 */
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *ctx = inlink->dst;
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *out = NULL;
    ThreadData td;
    int ret;

    out = ff_get_audio_buffer(outlink, FRAME_SIZE);
    if (!out) {
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }

    /* Propagate every frame property, not only the presentation timestamp. */
    ret = av_frame_copy_props(out, in);
    if (ret < 0) {
        av_frame_free(&out);
        av_frame_free(&in);
        return ret;
    }

    td.in = in; td.out = out;
    ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
                                                                   ff_filter_get_nb_threads(ctx)));

    av_frame_free(&in);
    return ff_filter_frame(outlink, out);
}
1438
1439
/**
 * Scheduling callback: consume exactly FRAME_SIZE samples at a time.
 *
 * When a full frame can be taken from the input it is filtered right away.
 * Otherwise status and frame requests are forwarded between the links and
 * FFERROR_NOT_READY is returned. The FF_FILTER_FORWARD_* macros may return
 * from this function themselves.
 */
static int activate(AVFilterContext *ctx)
{
    AVFilterLink *inlink = ctx->inputs[0];
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *frame = NULL;
    int got;

    FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);

    got = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &frame);
    if (got > 0)
        return filter_frame(inlink, frame);
    if (got < 0)
        return got;

    /* No complete frame available yet. */
    FF_FILTER_FORWARD_STATUS(inlink, outlink);
    FF_FILTER_FORWARD_WANTED(outlink, inlink);

    return FFERROR_NOT_READY;
}
1460
1461
/**
 * Filter initialisation: allocate the float DSP context, load the model
 * named by the "model" option, and precompute the analysis window and the
 * DCT table.
 *
 * Each failure path that used to return AVERROR(EINVAL) silently now logs
 * the reason first, so a missing option, an unreadable file and an
 * unparsable model can be told apart. Return codes are unchanged.
 *
 * @return 0 on success, AVERROR(ENOMEM) if the DSP context cannot be
 *         allocated, AVERROR(EINVAL) if no model is given or it cannot be
 *         opened or loaded
 */
static av_cold int init(AVFilterContext *ctx)
{
    AudioRNNContext *s = ctx->priv;
    FILE *f;

    s->fdsp = avpriv_float_dsp_alloc(0);
    if (!s->fdsp)
        return AVERROR(ENOMEM);

    if (!s->model_name) {
        av_log(ctx, AV_LOG_ERROR, "No model file specified.\n");
        return AVERROR(EINVAL);
    }
    f = av_fopen_utf8(s->model_name, "r");
    if (!f) {
        av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    s->model = rnnoise_model_from_file(f);
    fclose(f);
    if (!s->model) {
        av_log(ctx, AV_LOG_ERROR, "Failed to load model from file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    /* Symmetric window: fill the first FRAME_SIZE entries and mirror them
     * from the end of the WINDOW_SIZE-long array. */
    for (int i = 0; i < FRAME_SIZE; i++) {
        /* Inner sine, evaluated once instead of twice. */
        const double inner = sin(.5*M_PI*(i+.5)/FRAME_SIZE);

        s->window[i] = sin(.5*M_PI*inner * inner);
        s->window[WINDOW_SIZE - 1 - i] = s->window[i];
    }

    /* Cosine table used by dct(); row 0 is additionally scaled by sqrt(0.5). */
    for (int i = 0; i < NB_BANDS; i++) {
        for (int j = 0; j < NB_BANDS; j++) {
            s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
            if (j == 0)
                s->dct_table[j][i] *= sqrtf(.5);
        }
    }

    return 0;
}
1496
1497
/**
 * Release everything owned by the filter context: the float DSP context,
 * the model, and for every channel the three GRU state buffers and both
 * transform contexts, followed by the channel state array itself.
 */
static av_cold void uninit(AVFilterContext *ctx)
{
    AudioRNNContext *s = ctx->priv;

    av_freep(&s->fdsp);
    rnnoise_model_free(s->model);
    s->model = NULL;

    /* The loop is skipped entirely when the channel states were never allocated. */
    for (int ch = 0; s->st && ch < s->channels; ch++) {
        DenoiseState *st = &s->st[ch];

        av_freep(&st->rnn.vad_gru_state);
        av_freep(&st->rnn.noise_gru_state);
        av_freep(&st->rnn.denoise_gru_state);
        av_tx_uninit(&st->tx);
        av_tx_uninit(&st->txi);
    }

    av_freep(&s->st);
}
1516
1517
static const AVFilterPad inputs[] = {
1518
    {
1519
        .name         = "default",
1520
        .type         = AVMEDIA_TYPE_AUDIO,
1521
        .config_props = config_input,
1522
    },
1523
    { NULL }
1524
};
1525
1526
static const AVFilterPad outputs[] = {
1527
    {
1528
        .name          = "default",
1529
        .type          = AVMEDIA_TYPE_AUDIO,
1530
    },
1531
    { NULL }
1532
};
1533
1534
#define OFFSET(x) offsetof(AudioRNNContext, x)
1535
#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1536
1537
static const AVOption arnndn_options[] = {
1538
    { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1539
    { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1540
    { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1541
    { NULL }
1542
};
1543
1544
AVFILTER_DEFINE_CLASS(arnndn);
1545
1546
AVFilter ff_af_arnndn = {
1547
    .name          = "arnndn",
1548
    .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1549
    .query_formats = query_formats,
1550
    .priv_size     = sizeof(AudioRNNContext),
1551
    .priv_class    = &arnndn_class,
1552
    .activate      = activate,
1553
    .init          = init,
1554
    .uninit        = uninit,
1555
    .inputs        = inputs,
1556
    .outputs       = outputs,
1557
    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1558
                     AVFILTER_FLAG_SLICE_THREADS,
1559
};