GCC Code Coverage Report
Directory: ../../../ffmpeg/ Exec Total Coverage
File: src/libavfilter/af_arnndn.c Lines: 0 704 0.0 %
Date: 2020-09-25 23:16:12 Branches: 0 590 0.0 %

Line Branch Exec Source
1
/*
2
 * Copyright (c) 2018 Gregor Richards
3
 * Copyright (c) 2017 Mozilla
4
 * Copyright (c) 2005-2009 Xiph.Org Foundation
5
 * Copyright (c) 2007-2008 CSIRO
6
 * Copyright (c) 2008-2011 Octasic Inc.
7
 * Copyright (c) Jean-Marc Valin
8
 * Copyright (c) 2019 Paul B Mahol
9
 *
10
 * Redistribution and use in source and binary forms, with or without
11
 * modification, are permitted provided that the following conditions
12
 * are met:
13
 *
14
 * - Redistributions of source code must retain the above copyright
15
 *   notice, this list of conditions and the following disclaimer.
16
 *
17
 * - Redistributions in binary form must reproduce the above copyright
18
 *   notice, this list of conditions and the following disclaimer in the
19
 *   documentation and/or other materials provided with the distribution.
20
 *
21
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
 * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
 */
33
34
#include <float.h>
35
36
#include "libavutil/avassert.h"
37
#include "libavutil/avstring.h"
38
#include "libavutil/float_dsp.h"
39
#include "libavutil/opt.h"
40
#include "libavutil/tx.h"
41
#include "avfilter.h"
42
#include "audio.h"
43
#include "filters.h"
44
#include "formats.h"
45
46
/* Frame geometry: one frame is 120 << 2 = 480 samples. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
/* Analysis/synthesis window spans two frames (960 samples). */
#define WINDOW_SIZE (2*FRAME_SIZE)
/* Number of spectrum bins kept per window (481). */
#define FREQ_SIZE (FRAME_SIZE + 1)

/* Pitch analysis limits and the size of the pitch history (768 + 960). */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* NOTE: the argument is evaluated twice; do not pass side effects. */
#define SQUARE(x) ((x)*(x))

/* Number of frequency bands (entries of eband5ms[]). */
#define NB_BANDS 22

#define CEPS_MEM 8
#define NB_DELTA_CEPS 6

/* Length of the feature vector: 22 + 3 * 6 + 2 = 42. */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)

#define WEIGHTS_SCALE (1.f/256)

#define MAX_NEURONS 128

/* Activation identifiers stored in DenseLayer/GRULayer.activation. */
#define ACTIVATION_TANH    0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU    2

/* Unity gain in the float build (name kept from the fixed-point origin). */
#define Q15ONE 1.0f
75
/**
 * Fully connected layer as loaded by rnnoise_model_from_file().
 */
typedef struct DenseLayer {
    const float *bias;          /* nb_neurons values */
    const float *input_weights; /* nb_inputs * nb_neurons values */
    int nb_inputs;
    int nb_neurons;
    int activation;             /* one of the ACTIVATION_* identifiers */
} DenseLayer;
82
83
/**
 * Gated recurrent layer as loaded by rnnoise_model_from_file().
 *
 * The weight matrices hold three gate planes and are stored with the
 * first and second dimensions padded to a multiple of 4.
 */
typedef struct GRULayer {
    const float *bias;              /* 3 * nb_neurons values */
    const float *input_weights;     /* padded nb_inputs x nb_neurons x 3 */
    const float *recurrent_weights; /* padded nb_neurons x nb_neurons x 3 */
    int nb_inputs;
    int nb_neurons;
    int activation;                 /* one of the ACTIVATION_* identifiers */
} GRULayer;
91
92
typedef struct RNNModel {
93
    int input_dense_size;
94
    const DenseLayer *input_dense;
95
96
    int vad_gru_size;
97
    const GRULayer *vad_gru;
98
99
    int noise_gru_size;
100
    const GRULayer *noise_gru;
101
102
    int denoise_gru_size;
103
    const GRULayer *denoise_gru;
104
105
    int denoise_output_size;
106
    const DenseLayer *denoise_output;
107
108
    int vad_output_size;
109
    const DenseLayer *vad_output;
110
} RNNModel;
111
112
typedef struct RNNState {
113
    float *vad_gru_state;
114
    float *noise_gru_state;
115
    float *denoise_gru_state;
116
    RNNModel *model;
117
} RNNState;
118
119
typedef struct DenoiseState {
120
    float analysis_mem[FRAME_SIZE];
121
    float cepstral_mem[CEPS_MEM][NB_BANDS];
122
    int memid;
123
    DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124
    float pitch_buf[PITCH_BUF_SIZE];
125
    float pitch_enh_buf[PITCH_BUF_SIZE];
126
    float last_gain;
127
    int last_period;
128
    float mem_hp_x[2];
129
    float lastg[NB_BANDS];
130
    RNNState rnn;
131
    AVTXContext *tx, *txi;
132
    av_tx_fn tx_fn, txi_fn;
133
} DenoiseState;
134
135
typedef struct AudioRNNContext {
136
    const AVClass *class;
137
138
    char *model_name;
139
140
    int channels;
141
    DenoiseState *st;
142
143
    DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144
    DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
145
146
    RNNModel *model;
147
148
    AVFloatDSPContext *fdsp;
149
} AudioRNNContext;
150
151
/*
 * Activation codes as they appear in the model file. The loader maps
 * them onto the ACTIVATION_* identifiers; any unknown code is treated
 * as tanh.
 */
#define F_ACTIVATION_TANH       0
#define F_ACTIVATION_SIGMOID    1
#define F_ACTIVATION_RELU       2
154
155
/**
 * Release a model and every layer attached to it.
 *
 * For each layer that is present, its weight/bias arrays are freed
 * first, then the layer itself; finally the model structure is freed.
 * Layers that were never allocated (NULL) are skipped, so a partially
 * built model can be passed in. Passing NULL is a no-op.
 *
 * @param model model to free, may be NULL
 */
static void rnnoise_model_free(RNNModel *model)
{
/* The casts drop the const qualifier the layer pointers carry. */
#define FREE_DENSE(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)
#define FREE_GRU(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->recurrent_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)

    if (!model)
        return;
    FREE_DENSE(input_dense);
    FREE_GRU(vad_gru);
    FREE_GRU(noise_gru);
    FREE_GRU(denoise_gru);
    FREE_DENSE(denoise_output);
    FREE_DENSE(vad_output);
    av_free(model);
}
184
185
static RNNModel *rnnoise_model_from_file(FILE *f)
186
{
187
    RNNModel *ret;
188
    DenseLayer *input_dense;
189
    GRULayer *vad_gru;
190
    GRULayer *noise_gru;
191
    GRULayer *denoise_gru;
192
    DenseLayer *denoise_output;
193
    DenseLayer *vad_output;
194
    int in;
195
196
    if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
197
        return NULL;
198
199
    ret = av_calloc(1, sizeof(RNNModel));
200
    if (!ret)
201
        return NULL;
202
203
#define ALLOC_LAYER(type, name) \
204
    name = av_calloc(1, sizeof(type)); \
205
    if (!name) { \
206
        rnnoise_model_free(ret); \
207
        return NULL; \
208
    } \
209
    ret->name = name
210
211
    ALLOC_LAYER(DenseLayer, input_dense);
212
    ALLOC_LAYER(GRULayer, vad_gru);
213
    ALLOC_LAYER(GRULayer, noise_gru);
214
    ALLOC_LAYER(GRULayer, denoise_gru);
215
    ALLOC_LAYER(DenseLayer, denoise_output);
216
    ALLOC_LAYER(DenseLayer, vad_output);
217
218
#define INPUT_VAL(name) do { \
219
    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220
        rnnoise_model_free(ret); \
221
        return NULL; \
222
    } \
223
    name = in; \
224
    } while (0)
225
226
#define INPUT_ACTIVATION(name) do { \
227
    int activation; \
228
    INPUT_VAL(activation); \
229
    switch (activation) { \
230
    case F_ACTIVATION_SIGMOID: \
231
        name = ACTIVATION_SIGMOID; \
232
        break; \
233
    case F_ACTIVATION_RELU: \
234
        name = ACTIVATION_RELU; \
235
        break; \
236
    default: \
237
        name = ACTIVATION_TANH; \
238
    } \
239
    } while (0)
240
241
#define INPUT_ARRAY(name, len) do { \
242
    float *values = av_calloc((len), sizeof(float)); \
243
    if (!values) { \
244
        rnnoise_model_free(ret); \
245
        return NULL; \
246
    } \
247
    name = values; \
248
    for (int i = 0; i < (len); i++) { \
249
        if (fscanf(f, "%d", &in) != 1) { \
250
            rnnoise_model_free(ret); \
251
            return NULL; \
252
        } \
253
        values[i] = in; \
254
    } \
255
    } while (0)
256
257
#define INPUT_ARRAY3(name, len0, len1, len2) do { \
258
    float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
259
    if (!values) { \
260
        rnnoise_model_free(ret); \
261
        return NULL; \
262
    } \
263
    name = values; \
264
    for (int k = 0; k < (len0); k++) { \
265
        for (int i = 0; i < (len2); i++) { \
266
            for (int j = 0; j < (len1); j++) { \
267
                if (fscanf(f, "%d", &in) != 1) { \
268
                    rnnoise_model_free(ret); \
269
                    return NULL; \
270
                } \
271
                values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
272
            } \
273
        } \
274
    } \
275
    } while (0)
276
277
#define INPUT_DENSE(name) do { \
278
    INPUT_VAL(name->nb_inputs); \
279
    INPUT_VAL(name->nb_neurons); \
280
    ret->name ## _size = name->nb_neurons; \
281
    INPUT_ACTIVATION(name->activation); \
282
    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283
    INPUT_ARRAY(name->bias, name->nb_neurons); \
284
    } while (0)
285
286
#define INPUT_GRU(name) do { \
287
    INPUT_VAL(name->nb_inputs); \
288
    INPUT_VAL(name->nb_neurons); \
289
    ret->name ## _size = name->nb_neurons; \
290
    INPUT_ACTIVATION(name->activation); \
291
    INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292
    INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293
    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
294
    } while (0)
295
296
    INPUT_DENSE(input_dense);
297
    INPUT_GRU(vad_gru);
298
    INPUT_GRU(noise_gru);
299
    INPUT_GRU(denoise_gru);
300
    INPUT_DENSE(denoise_output);
301
    INPUT_DENSE(vad_output);
302
303
    if (vad_output->nb_neurons != 1) {
304
        rnnoise_model_free(ret);
305
        return NULL;
306
    }
307
308
    return ret;
309
}
310
311
/**
 * Advertise the supported formats: planar float samples, any channel
 * count, 48000 Hz only.
 *
 * @return 0 or a positive value on success, a negative AVERROR code otherwise
 */
static int query_formats(AVFilterContext *ctx)
{
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP,
        AV_SAMPLE_FMT_NONE
    };
    int sample_rates[] = { 48000, -1 };
    AVFilterChannelLayouts *layouts;
    AVFilterFormats *list;
    int ret;

    /* Sample formats. */
    list = ff_make_format_list(sample_fmts);
    if (!list)
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_formats(ctx, list)) < 0)
        return ret;

    /* Channel layouts / counts. */
    layouts = ff_all_channel_counts();
    if (!layouts)
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_channel_layouts(ctx, layouts)) < 0)
        return ret;

    /* Sample rates. */
    list = ff_make_format_list(sample_rates);
    if (!list)
        return AVERROR(ENOMEM);

    return ff_set_common_samplerates(ctx, list);
}
341
342
/**
 * Allocate the per-channel denoiser state for the negotiated input.
 *
 * For each channel this attaches the shared model, allocates the three
 * zeroed GRU state vectors (sizes rounded up to a multiple of 16 floats)
 * and creates the forward and inverse WINDOW_SIZE-point float FFTs.
 * Nothing is released here on failure.
 *
 * @return 0 on success, a negative AVERROR code otherwise
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;

    s->channels = inlink->channels;

    s->st = av_calloc(s->channels, sizeof(DenoiseState));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int ch = 0; ch < s->channels; ch++) {
        DenoiseState *st = &s->st[ch];
        RNNState *rnn = &st->rnn;
        int err;

        rnn->model = s->model;
        rnn->vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
        rnn->noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
        rnn->denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
        if (!(rnn->vad_gru_state && rnn->noise_gru_state && rnn->denoise_gru_state))
            return AVERROR(ENOMEM);

        /* Forward transform. */
        err = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
        if (err < 0)
            return err;

        /* Inverse transform. */
        err = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
        if (err < 0)
            return err;
    }

    return 0;
}
377
378
/**
 * Second order recursive filter with two state variables.
 *
 * Per sample:
 *   y      = x + mem[0]
 *   mem[0] = mem[1] + (b[0] * x - a[0] * y)
 *   mem[1] =           b[1] * x - a[1] * y
 *
 * Each input sample is read before the matching output is written, so
 * y may point at the same buffer as x.
 *
 * @param y   output, N samples
 * @param mem filter state, updated in place
 * @param x   input, N samples
 * @param b   two feed-forward coefficients
 * @param a   two feedback coefficients
 * @param N   number of samples
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    for (int left = N; left > 0; left--, x++, y++) {
        const float in  = *x;
        const float res = in + mem[0];

        mem[0] = mem[1] + (b[0]*in - a[0]*res);
        mem[1] = (b[1]*in - a[1]*res);
        *y = res;
    }
}
391
392
/*
 * Element-count based wrappers around memmove/memset/memcpy; n is a
 * number of elements of *dst, not bytes. The "0*((dst)-(src))" term adds
 * nothing to the size, but subtracting the two pointers only compiles
 * when dst and src point to compatible types.
 */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
#define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
395
396
/**
 * Forward FFT of one real, already windowed block.
 *
 * The WINDOW_SIZE real samples are promoted to complex values with a
 * zero imaginary part, transformed with the channel's forward FFT, and
 * the first FREQ_SIZE bins are copied to out.
 *
 * The scratch buffers are 32-byte aligned because the transform API
 * expects aligned input/output, and the stride passed is the size of
 * one complex sample.
 *
 * @param st  channel state providing the FFT context
 * @param out FREQ_SIZE complex bins
 * @param in  WINDOW_SIZE real samples
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    for (int i = 0; i < WINDOW_SIZE; i++) {
        x[i].re = in[i];
        x[i].im = 0;
    }

    st->tx_fn(st->tx, y, x, sizeof(AVComplexFloat));

    RNN_COPY(out, y, FREQ_SIZE);
}
410
411
/**
 * Inverse FFT producing one real block.
 *
 * Only FREQ_SIZE bins are supplied; the remaining bins are rebuilt as
 * the complex conjugate mirror of the given ones (bin i takes the
 * conjugate of bin WINDOW_SIZE - i). After the inverse transform the
 * real part of every sample is divided by WINDOW_SIZE.
 *
 * The scratch buffers are 32-byte aligned because the transform API
 * expects aligned input/output, and the stride passed is the size of
 * one complex sample.
 *
 * @param st  channel state providing the inverse FFT context
 * @param out WINDOW_SIZE real samples
 * @param in  FREQ_SIZE complex bins
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    RNN_COPY(x, in, FREQ_SIZE);

    for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
        x[i].re =  x[WINDOW_SIZE - i].re;
        x[i].im = -x[WINDOW_SIZE - i].im;
    }

    st->txi_fn(st->txi, y, x, sizeof(AVComplexFloat));

    for (int i = 0; i < WINDOW_SIZE; i++)
        out[i] = y[i].re / WINDOW_SIZE;
}
428
429
/*
 * Band edges; the band code shifts them left by FRAME_SIZE_SHIFT to get
 * spectrum bin indices. The row of labels gives the nominal frequency
 * of each edge.
 */
static const uint8_t eband5ms[] = {
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
  0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
433
434
static void compute_band_energy(float *bandE, const AVComplexFloat *X)
435
{
436
    float sum[NB_BANDS] = {0};
437
438
    for (int i = 0; i < NB_BANDS - 1; i++) {
439
        int band_size;
440
441
        band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
442
        for (int j = 0; j < band_size; j++) {
443
            float tmp, frac = (float)j / band_size;
444
445
            tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
446
            tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
447
            sum[i]     += (1.f - frac) * tmp;
448
            sum[i + 1] +=        frac  * tmp;
449
        }
450
    }
451
452
    sum[0] *= 2;
453
    sum[NB_BANDS - 1] *= 2;
454
455
    for (int i = 0; i < NB_BANDS; i++)
456
        bandE[i] = sum[i];
457
}
458
459
static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
460
{
461
    float sum[NB_BANDS] = { 0 };
462
463
    for (int i = 0; i < NB_BANDS - 1; i++) {
464
        int band_size;
465
466
        band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
467
        for (int j = 0; j < band_size; j++) {
468
            float tmp, frac = (float)j / band_size;
469
470
            tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
471
            tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
472
            sum[i]     += (1 - frac) * tmp;
473
            sum[i + 1] +=      frac  * tmp;
474
        }
475
    }
476
477
    sum[0] *= 2;
478
    sum[NB_BANDS-1] *= 2;
479
480
    for (int i = 0; i < NB_BANDS; i++)
481
        bandE[i] = sum[i];
482
}
483
484
/**
 * Analyse one input frame.
 *
 * Builds a WINDOW_SIZE block from the previous frame (kept in
 * st->analysis_mem) followed by the new one, remembers the new frame
 * for the next call, applies the window, and computes the spectrum and
 * its band energies.
 *
 * @param X  FREQ_SIZE spectrum bins (output)
 * @param Ex NB_BANDS band energies (output)
 * @param in FRAME_SIZE new samples
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, block, [WINDOW_SIZE]);

    /* [ previous frame | current frame ] */
    RNN_COPY(block, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(&block[FRAME_SIZE], in, FRAME_SIZE);
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);

    s->fdsp->vector_fmul(block, block, s->window, WINDOW_SIZE);
    forward_transform(st, X, block);
    compute_band_energy(Ex, X);
}
495
496
/**
 * Synthesise one output frame by overlap-add.
 *
 * The spectrum is transformed back and windowed; the tail saved by the
 * previous call is accumulated (with factor 1) onto the first half,
 * which becomes the output, and the second half is saved for the next
 * call.
 *
 * @param out FRAME_SIZE output samples
 * @param y   FREQ_SIZE spectrum bins
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, block, [WINDOW_SIZE]);

    inverse_transform(st, block, y);
    s->fdsp->vector_fmul(block, block, s->window, WINDOW_SIZE);

    /* Overlap-add with the tail of the previous block. */
    s->fdsp->vector_fmac_scalar(block, st->synthesis_mem, 1.f, FRAME_SIZE);

    RNN_COPY(out, block, FRAME_SIZE);
    RNN_COPY(st->synthesis_mem, block + FRAME_SIZE, FRAME_SIZE);
}
506
507
/**
 * Accumulate four neighbouring cross-correlation lags at once:
 *   sum[k] += x[j] * y[j + k]   for j = 0..len-1 and k = 0..3
 *
 * The loop is unrolled by four with a rotating window (w0..w3) over y,
 * so each y sample is loaded once. y[0..2] are always read, and for
 * len >= 1 the highest index read is y[len + 2].
 *
 * @param x   len samples
 * @param y   at least len + 3 samples
 * @param sum four accumulators; not cleared here
 * @param len number of products per lag
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    float w0 = y[0], w1 = y[1], w2 = y[2], w3 = 0;
    float s;
    int j, rem;

    for (j = 0; j < len - 3; j += 4) {
        s  = x[j];
        w3 = y[j + 3];
        sum[0] += s * w0;
        sum[1] += s * w1;
        sum[2] += s * w2;
        sum[3] += s * w3;

        s  = x[j + 1];
        w0 = y[j + 4];
        sum[0] += s * w1;
        sum[1] += s * w2;
        sum[2] += s * w3;
        sum[3] += s * w0;

        s  = x[j + 2];
        w1 = y[j + 5];
        sum[0] += s * w2;
        sum[1] += s * w3;
        sum[2] += s * w0;
        sum[3] += s * w1;

        s  = x[j + 3];
        w2 = y[j + 6];
        sum[0] += s * w3;
        sum[1] += s * w0;
        sum[2] += s * w1;
        sum[3] += s * w2;
    }

    /* Up to three samples are left over. */
    rem = len - j;

    if (rem > 0) {
        s  = x[j];
        w3 = y[j + 3];
        sum[0] += s * w0;
        sum[1] += s * w1;
        sum[2] += s * w2;
        sum[3] += s * w3;
    }

    if (rem > 1) {
        s  = x[j + 1];
        w0 = y[j + 4];
        sum[0] += s * w1;
        sum[1] += s * w2;
        sum[2] += s * w3;
        sum[3] += s * w0;
    }

    if (rem > 2) {
        s  = x[j + 2];
        w1 = y[j + 5];
        sum[0] += s * w2;
        sum[1] += s * w3;
        sum[2] += s * w0;
        sum[3] += s * w1;
    }
}
575
576
/**
 * Dot product of two float vectors, accumulated in index order.
 *
 * @return sum of x[i] * y[i] for i in [0, N); 0 when N <= 0
 */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;
    int n = 0;

    while (n < N) {
        acc += x[n] * y[n];
        n++;
    }

    return acc;
}
586
587
/**
 * Cross-correlation of x against y for lags 0..max_pitch-1:
 *   xcorr[lag] = sum over j < len of x[j] * y[lag + j]
 *
 * Lags are processed four at a time with xcorr_kernel(); whatever is
 * left when max_pitch is not a multiple of 4 uses the plain inner
 * product.
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    for (; lag < max_pitch - 3; lag += 4) {
        float quad[4] = { 0, 0, 0, 0 };

        xcorr_kernel(x, y + lag, quad, len);

        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = quad[k];
    }

    while (lag < max_pitch) {
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
        lag++;
    }
}
607
608
static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
609
                         float       *ac,  /* out: [0...lag-1] ac values */
610
                         const float *window,
611
                         int          overlap,
612
                         int          lag,
613
                         int          n)
614
{
615
    int fastN = n - lag;
616
    int shift;
617
    const float *xptr;
618
    float xx[PITCH_BUF_SIZE>>1];
619
620
    if (overlap == 0) {
621
        xptr = x;
622
    } else {
623
        for (int i = 0; i < n; i++)
624
            xx[i] = x[i];
625
        for (int i = 0; i < overlap; i++) {
626
            xx[i] = x[i] * window[i];
627
            xx[n-i-1] = x[n-i-1] * window[i];
628
        }
629
        xptr = xx;
630
    }
631
632
    shift = 0;
633
    celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
634
635
    for (int k = 0; k <= lag; k++) {
636
        float d = 0.f;
637
638
        for (int i = k + fastN; i < n; i++)
639
            d += xptr[i] * xptr[i-k];
640
        ac[k] += d;
641
    }
642
643
    return shift;
644
}
645
646
/**
 * Derive p prediction coefficients from autocorrelation values with a
 * Levinson-Durbin style recursion.
 *
 * All coefficients are cleared first. Nothing else happens when ac[0]
 * is zero, and the recursion stops early once the remaining error is
 * below 0.001 * ac[0].
 *
 * @param lpc out: [0...p-1] LPC coefficients
 * @param ac  in:  [0...p] autocorrelation values
 * @param p   prediction order
 */
static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
                const float *ac,   /* in:  [0...p] autocorrelation values  */
                          int p)
{
    float err = ac[0];

    for (int i = 0; i < p; i++)
        lpc[i] = 0;

    if (ac[0] == 0)
        return;

    for (int i = 0; i < p; i++) {
        float refl, acc = 0;

        /* Reflection coefficient of this iteration. */
        for (int j = 0; j < i; j++)
            acc += (lpc[j] * ac[i - j]);
        acc += ac[i + 1];
        refl = -acc/err;

        /* Update the coefficients pairwise from both ends. */
        lpc[i] = refl;
        for (int j = 0; j < (i + 1) >> 1; j++) {
            const float lo = lpc[j];
            const float hi = lpc[i-1-j];

            lpc[j]     = lo + (refl*hi);
            lpc[i-1-j] = hi + (refl*lo);
        }

        err = err - (refl * refl *err);
        /* Bail out once we get 30 dB gain */
        if (err < .001f * ac[0])
            break;
    }
}
678
679
/**
 * Five-tap filter over past inputs:
 *   y[i] = x[i] + num[0]*x[i-1] + num[1]*x[i-2] + ... + num[4]*x[i-5]
 *
 * mem[] supplies the five inputs preceding x[0] (most recent first) and
 * receives the last five inputs on return. Each x[i] is read before
 * y[i] is written, so y may be the same buffer as x.
 *
 * @param x   N input samples
 * @param num five filter coefficients
 * @param y   N output samples
 * @param N   number of samples
 * @param mem five history samples, updated
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float taps[5], hist[5];

    for (int k = 0; k < 5; k++) {
        taps[k] = num[k];
        hist[k] = mem[k];
    }

    for (int i = 0; i < N; i++) {
        const float cur = x[i];
        float sum = cur;

        for (int k = 0; k < 5; k++)
            sum += (taps[k]*hist[k]);

        /* Shift the history by one sample. */
        for (int k = 4; k > 0; k--)
            hist[k] = hist[k - 1];
        hist[0] = cur;

        y[i] = sum;
    }

    for (int k = 0; k < 5; k++)
        mem[k] = hist[k];
}
721
722
static void pitch_downsample(float *x[], float *x_lp,
723
                             int len, int C)
724
{
725
    float ac[5];
726
    float tmp=Q15ONE;
727
    float lpc[4], mem[5]={0,0,0,0,0};
728
    float lpc2[5];
729
    float c1 = .8f;
730
731
    for (int i = 1; i < len >> 1; i++)
732
        x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
733
    x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
734
    if (C==2) {
735
        for (int i = 1; i < len >> 1; i++)
736
            x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
737
        x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
738
    }
739
740
    celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
741
742
    /* Noise floor -40 dB */
743
    ac[0] *= 1.0001f;
744
    /* Lag windowing */
745
    for (int i = 1; i <= 4; i++) {
746
        /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
747
        ac[i] -= ac[i]*(.008f*i)*(.008f*i);
748
    }
749
750
    celt_lpc(lpc, ac, 4);
751
    for (int i = 0; i < 4; i++) {
752
        tmp = .9f * tmp;
753
        lpc[i] = (lpc[i] * tmp);
754
    }
755
    /* Add a zero */
756
    lpc2[0] = lpc[0] + .8f;
757
    lpc2[1] = lpc[1] + (c1 * lpc[0]);
758
    lpc2[2] = lpc[2] + (c1 * lpc[1]);
759
    lpc2[3] = lpc[3] + (c1 * lpc[2]);
760
    lpc2[4] = (c1 * lpc[3]);
761
    celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
762
}
763
764
/**
 * Two dot products sharing the same left operand, computed in one pass.
 *
 * @param x   common vector, N samples
 * @param y01 first right-hand vector
 * @param y02 second right-hand vector
 * @param N   number of samples
 * @param xy1 receives sum of x[i] * y01[i]
 * @param xy2 receives sum of x[i] * y02[i]
 */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;
    int n = 0;

    while (n < N) {
        acc1 += (x[n] * y01[n]);
        acc2 += (x[n] * y02[n]);
        n++;
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
777
778
static float compute_pitch_gain(float xy, float xx, float yy)
779
{
780
    return xy / sqrtf(1.f + xx * yy);
781
}
782
783
static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
784
static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
785
                             int *T0_, int prev_period, float prev_gain)
786
{
787
    int k, i, T, T0;
788
    float g, g0;
789
    float pg;
790
    float xy,xx,yy,xy2;
791
    float xcorr[3];
792
    float best_xy, best_yy;
793
    int offset;
794
    int minperiod0;
795
    float yy_lookup[PITCH_MAX_PERIOD+1];
796
797
    minperiod0 = minperiod;
798
    maxperiod /= 2;
799
    minperiod /= 2;
800
    *T0_ /= 2;
801
    prev_period /= 2;
802
    N /= 2;
803
    x += maxperiod;
804
    if (*T0_>=maxperiod)
805
        *T0_=maxperiod-1;
806
807
    T = T0 = *T0_;
808
    dual_inner_prod(x, x, x-T0, N, &xx, &xy);
809
    yy_lookup[0] = xx;
810
    yy=xx;
811
    for (i = 1; i <= maxperiod; i++) {
812
        yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
813
        yy_lookup[i] = FFMAX(0, yy);
814
    }
815
    yy = yy_lookup[T0];
816
    best_xy = xy;
817
    best_yy = yy;
818
    g = g0 = compute_pitch_gain(xy, xx, yy);
819
    /* Look for any pitch at T/k */
820
    for (k = 2; k <= 15; k++) {
821
        int T1, T1b;
822
        float g1;
823
        float cont=0;
824
        float thresh;
825
        T1 = (2*T0+k)/(2*k);
826
        if (T1 < minperiod)
827
            break;
828
        /* Look for another strong correlation at T1b */
829
        if (k==2)
830
        {
831
            if (T1+T0>maxperiod)
832
                T1b = T0;
833
            else
834
                T1b = T0+T1;
835
        } else
836
        {
837
            T1b = (2*second_check[k]*T0+k)/(2*k);
838
        }
839
        dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
840
        xy = .5f * (xy + xy2);
841
        yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
842
        g1 = compute_pitch_gain(xy, xx, yy);
843
        if (FFABS(T1-prev_period)<=1)
844
            cont = prev_gain;
845
        else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
846
            cont = prev_gain * .5f;
847
        else
848
            cont = 0;
849
        thresh = FFMAX(.3f, (.7f * g0) - cont);
850
        /* Bias against very high pitch (very short period) to avoid false-positives
851
           due to short-term correlation */
852
        if (T1<3*minperiod)
853
            thresh = FFMAX(.4f, (.85f * g0) - cont);
854
        else if (T1<2*minperiod)
855
            thresh = FFMAX(.5f, (.9f * g0) - cont);
856
        if (g1 > thresh)
857
        {
858
            best_xy = xy;
859
            best_yy = yy;
860
            T = T1;
861
            g = g1;
862
        }
863
    }
864
    best_xy = FFMAX(0, best_xy);
865
    if (best_yy <= best_xy)
866
        pg = Q15ONE;
867
    else
868
        pg = best_xy/(best_yy + 1);
869
870
    for (k = 0; k < 3; k++)
871
        xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
872
    if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
873
        offset = 1;
874
    else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
875
        offset = -1;
876
    else
877
        offset = 0;
878
    if (pg > g)
879
        pg = g;
880
    *T0_ = 2*T+offset;
881
882
    if (*T0_<minperiod0)
883
        *T0_=minperiod0;
884
    return pg;
885
}
886
887
/**
 * Find the two lags with the highest normalised correlation.
 *
 * A lag i is ranked by xcorr[i]^2 / Syy, where Syy is 1 plus the energy
 * of y[i .. i+len-1] (kept at or above 1). Ratios are compared by
 * cross-multiplication, and only lags with a positive xcorr are
 * considered.
 *
 * @param xcorr      max_pitch correlation values
 * @param y          len + max_pitch samples
 * @param len        window length
 * @param max_pitch  number of lags
 * @param best_pitch out: best lag in [0], runner-up in [1]
 *                   (0 and 1 when no lag qualifies)
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float best_num[2] = { -1, -1 };
    float best_den[2] = {  0,  0 };
    float Syy = 1.f;

    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        Syy += y[j] * y[j];

    for (int i = 0; i < max_pitch; i++) {
        const float c = xcorr[i];

        if (c > 0) {
            /* Considering the range of the correlation, this should avoid both
               underflows and overflows (inf) when squaring it */
            const float scaled = c * 1e-12f;
            const float num = scaled * scaled;
            const int beats_second = (num * best_den[1]) > (best_num[1] * Syy);
            const int beats_first  = (num * best_den[0]) > (best_num[0] * Syy);

            if (beats_second && beats_first) {
                /* New best: the previous best becomes the runner-up. */
                best_num[1] = best_num[0];
                best_den[1] = best_den[0];
                best_pitch[1] = best_pitch[0];
                best_num[0] = num;
                best_den[0] = Syy;
                best_pitch[0] = i;
            } else if (beats_second) {
                best_num[1] = num;
                best_den[1] = Syy;
                best_pitch[1] = i;
            }
        }

        /* Slide the energy window by one sample. */
        Syy += y[i+len]*y[i+len] - y[i] * y[i];
        if (1 > Syy)
            Syy = 1;
    }
}
933
934
static void pitch_search(const float *x_lp, float *y,
935
                         int len, int max_pitch, int *pitch)
936
{
937
    int lag;
938
    int best_pitch[2]={0,0};
939
    int offset;
940
941
    float x_lp4[WINDOW_SIZE];
942
    float y_lp4[WINDOW_SIZE];
943
    float xcorr[WINDOW_SIZE];
944
945
    lag = len+max_pitch;
946
947
    /* Downsample by 2 again */
948
    for (int j = 0; j < len >> 2; j++)
949
        x_lp4[j] = x_lp[2*j];
950
    for (int j = 0; j < lag >> 2; j++)
951
        y_lp4[j] = y[2*j];
952
953
    /* Coarse search with 4x decimation */
954
955
    celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
956
957
    find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
958
959
    /* Finer search with 2x decimation */
960
    for (int i = 0; i < max_pitch >> 1; i++) {
961
        float sum;
962
        xcorr[i] = 0;
963
        if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
964
            continue;
965
        sum = celt_inner_prod(x_lp, y+i, len>>1);
966
        xcorr[i] = FFMAX(-1, sum);
967
    }
968
969
    find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
970
971
    /* Refine by pseudo-interpolation */
972
    if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
973
        float a, b, c;
974
975
        a = xcorr[best_pitch[0] - 1];
976
        b = xcorr[best_pitch[0]];
977
        c = xcorr[best_pitch[0] + 1];
978
        if (c - a > .7f * (b - a))
979
            offset = 1;
980
        else if (a - c > .7f * (b-c))
981
            offset = -1;
982
        else
983
            offset = 0;
984
    } else {
985
        offset = 0;
986
    }
987
988
    *pitch = 2 * best_pitch[0] - offset;
989
}
990
991
/**
 * Transform NB_BANDS values with the rows of s->dct_table and scale the
 * result by sqrt(2 / NB_BANDS).
 *
 * The dot product runs over FFALIGN(NB_BANDS, 4) elements, the padded
 * row length of the table. The input is therefore first copied into a
 * zero-padded, 32-byte aligned scratch vector: only NB_BANDS values are
 * read from in, so an NB_BANDS-sized caller buffer is never read past
 * its end and stray NaN/Inf values cannot leak into the sum.
 *
 * @param out NB_BANDS output values
 * @param in  NB_BANDS input values
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    LOCAL_ALIGNED_32(float, padded, [FFALIGN(NB_BANDS, 4)]);

    RNN_COPY(padded, in, NB_BANDS);
    for (int i = NB_BANDS; i < FFALIGN(NB_BANDS, 4); i++)
        padded[i] = 0.f;

    for (int i = 0; i < NB_BANDS; i++) {
        float sum;

        sum = s->fdsp->scalarproduct_float(padded, s->dct_table[i], FFALIGN(NB_BANDS, 4));
        out[i] = sum * sqrtf(2.f / NB_BANDS);
    }
}
1000
1001
/**
 * Analyse one input frame and build the feature vector fed to the network.
 *
 * Updates st->pitch_buf, st->last_period, st->last_gain and, for non-silent
 * frames, the cepstral history (st->cepstral_mem / st->memid).
 *
 * @param X        receives the spectrum of the current frame
 * @param P        receives the spectrum of the pitch-delayed window
 * @param Ex       receives band energies of X
 * @param Ep       receives band energies of P
 * @param Exp      receives the normalized band correlation between X and P
 * @param features receives NB_FEATURES values (all zero for silent frames)
 * @param in       FRAME_SIZE input samples
 * @return 1 when the summed band energy is below 0.04 (features are cleared
 *         and the cepstral history is left untouched), 0 otherwise
 */
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
{
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
    float pitch_buf[PITCH_BUF_SIZE>>1];
    float corr_ceps[NB_BANDS];
    float *pre[1];
    float *ceps_0, *ceps_1, *ceps_2;
    float total_energy = 0;
    float spec_variability = 0;
    float follow = -2, logMax = -2;
    float gain;
    int pitch_index;

    frame_analysis(s, st, X, Ex, in);

    /* Slide the pitch history by one frame and append the new samples. */
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);

    /* Pitch estimation on the downsampled history. */
    pre[0] = &st->pitch_buf[0];
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
    pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
                 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
    pitch_index = PITCH_MAX_PERIOD-pitch_index;

    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
                           PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
    st->last_period = pitch_index;
    st->last_gain   = gain;

    /* Window and transform the signal delayed by the detected period. */
    for (int i = 0; i < WINDOW_SIZE; i++)
        p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];

    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
    forward_transform(st, P, p);
    compute_band_energy(Ep, P);
    compute_band_corr(Exp, X, P);

    for (int i = 0; i < NB_BANDS; i++)
        Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);

    /* First NB_DELTA_CEPS DCT coefficients of the band correlation. */
    dct(s, corr_ceps, Exp);

    for (int i = 0; i < NB_DELTA_CEPS; i++)
        features[NB_BANDS+2*NB_DELTA_CEPS+i] = corr_ceps[i];

    features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
    features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
    features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);

    /* Log band energies, floored relative to the running maximum and to a
     * follower that decays by 1.5 per band. */
    for (int i = 0; i < NB_BANDS; i++) {
        Ly[i] = log10f(1e-2f + Ex[i]);
        Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
        logMax = FFMAX(logMax, Ly[i]);
        follow = FFMAX(follow-1.5, Ly[i]);
        total_energy += Ex[i];
    }

    if (total_energy < 0.04f) {
        /* If there's no audio, avoid messing up the state. */
        RNN_CLEAR(features, NB_FEATURES);
        return 1;
    }

    dct(s, features, Ly);
    features[0] -= 12;
    features[1] -= 4;

    /* Current slot of the circular cepstral history plus the two before it. */
    ceps_0 = st->cepstral_mem[st->memid];
    if (st->memid < 1)
        ceps_1 = st->cepstral_mem[CEPS_MEM+st->memid-1];
    else
        ceps_1 = st->cepstral_mem[st->memid-1];
    if (st->memid < 2)
        ceps_2 = st->cepstral_mem[CEPS_MEM+st->memid-2];
    else
        ceps_2 = st->cepstral_mem[st->memid-2];

    for (int i = 0; i < NB_BANDS; i++)
        ceps_0[i] = features[i];

    /* Advance the history write position, wrapping at CEPS_MEM. */
    st->memid++;
    if (st->memid == CEPS_MEM)
        st->memid = 0;

    /* Sum, first difference and second difference over the three frames. */
    for (int i = 0; i < NB_DELTA_CEPS; i++) {
        features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
        features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
        features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
    }

    /* Spectral variability: for every stored frame, the squared distance to
     * its nearest other stored frame, accumulated over the history. */
    for (int i = 0; i < CEPS_MEM; i++) {
        float mindist = 1e15f;

        for (int j = 0; j < CEPS_MEM; j++) {
            float dist = 0.f;

            if (j == i)
                continue;

            for (int k = 0; k < NB_BANDS; k++) {
                float diff = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];

                dist += diff*diff;
            }

            mindist = FFMIN(mindist, dist);
        }

        spec_variability += mindist;
    }

    features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;

    return 0;
}
1108
1109
static void interp_band_gain(float *g, const float *bandE)
1110
{
1111
    memset(g, 0, sizeof(*g) * FREQ_SIZE);
1112
1113
    for (int i = 0; i < NB_BANDS - 1; i++) {
1114
        const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1115
1116
        for (int j = 0; j < band_size; j++) {
1117
            float frac = (float)j / band_size;
1118
1119
            g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1120
        }
1121
    }
1122
}
1123
1124
/**
 * Mix the pitch-delayed spectrum P into X with a per-band weight, then
 * rescale X so that each band keeps its original energy Ex.
 *
 * @param X   spectrum of the current frame, modified in place
 * @param P   spectrum of the pitch-delayed window
 * @param Ex  band energies of X before filtering
 * @param Ep  band energies of P
 * @param Exp band correlation between X and P
 * @param g   band gains
 */
static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
                         const float *Exp, const float *g)
{
    float rf[FREQ_SIZE] = {0};
    float normf[FREQ_SIZE] = {0};
    float r[NB_BANDS];
    float norm[NB_BANDS];
    float newE[NB_BANDS];

    /* Per-band mixing weight for P. */
    for (int b = 0; b < NB_BANDS; b++) {
        float ratio;

        if (Exp[b]>g[b])
            ratio = 1;
        else
            ratio = SQUARE(Exp[b])*(1-SQUARE(g[b]))/(.001 + SQUARE(g[b])*(1-SQUARE(Exp[b])));

        r[b]  = sqrtf(av_clipf(ratio, 0, 1));
        r[b] *= sqrtf(Ex[b]/(1e-8+Ep[b]));
    }

    interp_band_gain(rf, r);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re += rf[k]*P[k].re;
        X[k].im += rf[k]*P[k].im;
    }

    /* Renormalize so that the band energies match Ex again. */
    compute_band_energy(newE, X);
    for (int b = 0; b < NB_BANDS; b++)
        norm[b] = sqrtf(Ex[b] / (1e-8+newE[b]));

    interp_band_gain(normf, norm);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re *= normf[k];
        X[k].im *= normf[k];
    }
}
1154
1155
static const float tansig_table[201] = {
1156
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1157
    0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1158
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1159
    0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1160
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1161
    0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1162
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1163
    0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1164
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1165
    0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1166
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1167
    0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1168
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1169
    0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1170
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1171
    0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1172
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1173
    0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1174
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1175
    0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1176
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1177
    0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1178
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1179
    0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1180
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1181
    0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1182
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1183
    0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1184
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1185
    0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1186
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1187
    0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1188
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1189
    0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1190
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1191
    0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1192
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1193
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1194
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1195
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1196
    1.000000f,
1197
};
1198
1199
static inline float tansig_approx(float x)
1200
{
1201
    float y, dy;
1202
    float sign=1;
1203
    int i;
1204
1205
    /* Tests are reversed to catch NaNs */
1206
    if (!(x<8))
1207
        return 1;
1208
    if (!(x>-8))
1209
        return -1;
1210
    /* Another check in case of -ffast-math */
1211
1212
    if (isnan(x))
1213
       return 0;
1214
1215
    if (x < 0) {
1216
       x=-x;
1217
       sign=-1;
1218
    }
1219
    i = (int)floor(.5f+25*x);
1220
    x -= .04f*i;
1221
    y = tansig_table[i];
1222
    dy = 1-y*y;
1223
    y = y + x*dy*(1 - y*x);
1224
    return sign*y;
1225
}
1226
1227
/**
 * Sigmoid approximation built on tansig_approx():
 * 0.5 + 0.5 * tansig_approx(x / 2).
 */
static inline float sigmoid_approx(float x)
{
    const float t = tansig_approx(.5f*x);

    return .5f + .5f*t;
}
1231
1232
static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1233
{
1234
    const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1235
1236
    for (int i = 0; i < N; i++) {
1237
        /* Compute update gate. */
1238
        float sum = layer->bias[i];
1239
1240
        for (int j = 0; j < M; j++)
1241
            sum += layer->input_weights[j * stride + i] * input[j];
1242
1243
        output[i] = WEIGHTS_SCALE * sum;
1244
    }
1245
1246
    if (layer->activation == ACTIVATION_SIGMOID) {
1247
        for (int i = 0; i < N; i++)
1248
            output[i] = sigmoid_approx(output[i]);
1249
    } else if (layer->activation == ACTIVATION_TANH) {
1250
        for (int i = 0; i < N; i++)
1251
            output[i] = tansig_approx(output[i]);
1252
    } else if (layer->activation == ACTIVATION_RELU) {
1253
        for (int i = 0; i < N; i++)
1254
            output[i] = FFMAX(0, output[i]);
1255
    } else {
1256
        av_assert0(0);
1257
    }
1258
}
1259
1260
/**
 * Advance one GRU layer by a single step, updating its state in place.
 *
 * Weight rows are laid out with a stride of three padded blocks (update
 * gate, reset gate, candidate), each padded to a multiple of four so the
 * dot products can run through fdsp->scalarproduct_float.
 *
 * @param s     filter context providing fdsp
 * @param gru   layer description (weights, bias, sizes, activation)
 * @param state layer state; read up to FFALIGN(nb_neurons, 4) elements and
 *              overwritten with nb_neurons new values
 * @param input layer input; read up to FFALIGN(nb_inputs, 4) elements
 */
static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
{
    LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
    const int n_in     = gru->nb_inputs;
    const int n_out    = gru->nb_neurons;
    const int rec_len  = FFALIGN(n_out, 4);
    const int in_len   = FFALIGN(n_in, 4);
    const int rec_row  = 3 * rec_len;
    const int in_row   = 3 * in_len;

    /* Update gate z and reset gate r; both depend only on the old state. */
    for (int i = 0; i < n_out; i++) {
        float zsum = gru->bias[i];
        float rsum = gru->bias[n_out + i];

        zsum += s->fdsp->scalarproduct_float(gru->input_weights + i * in_row, input, in_len);
        zsum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * rec_row, state, rec_len);
        z[i] = sigmoid_approx(WEIGHTS_SCALE * zsum);

        rsum += s->fdsp->scalarproduct_float(gru->input_weights + in_len + i * in_row, input, in_len);
        rsum += s->fdsp->scalarproduct_float(gru->recurrent_weights + rec_len + i * rec_row, state, rec_len);
        r[i] = sigmoid_approx(WEIGHTS_SCALE * rsum);
    }

    /* Candidate output, blended with the old state through z. */
    for (int i = 0; i < n_out; i++) {
        float sum = gru->bias[2 * n_out + i];

        sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * in_len + i * in_row, input, in_len);
        for (int j = 0; j < n_out; j++)
            sum += gru->recurrent_weights[2 * rec_len + i * rec_row + j] * state[j] * r[j];

        if (gru->activation == ACTIVATION_SIGMOID)
            sum = sigmoid_approx(WEIGHTS_SCALE * sum);
        else if (gru->activation == ACTIVATION_TANH)
            sum = tansig_approx(WEIGHTS_SCALE * sum);
        else if (gru->activation == ACTIVATION_RELU)
            sum = FFMAX(0, WEIGHTS_SCALE * sum);
        else
            av_assert0(0);

        h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
    }

    /* The new state is published only after every neuron has been computed. */
    RNN_COPY(state, h, n_out);
}
1310
1311
#define INPUT_SIZE 42
1312
1313
static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1314
{
1315
    LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1316
    LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1317
    LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1318
1319
    compute_dense(rnn->model->input_dense, dense_out, input);
1320
    compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1321
    compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1322
1323
    memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1324
    memcpy(noise_input + rnn->model->input_dense_size,
1325
           rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1326
    memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1327
           input, INPUT_SIZE * sizeof(float));
1328
1329
    compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1330
1331
    memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1332
    memcpy(denoise_input + rnn->model->vad_gru_size,
1333
           rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1334
    memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1335
           input, INPUT_SIZE * sizeof(float));
1336
1337
    compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1338
    compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1339
}
1340
1341
/**
 * Denoise one frame of one channel.
 *
 * The input is passed through a biquad filter, analysed, and (unless the
 * frame is classified as silent) run through the network, the pitch filter
 * and the interpolated band gains before being synthesized into out.
 *
 * @param s   filter context
 * @param st  per-channel state (updated)
 * @param out receives the synthesized frame
 * @param in  FRAME_SIZE input samples
 * @return the value produced by the VAD output layer, or 0 for silent frames
 */
static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
{
    static const float a_hp[2] = {-1.99599, 0.99600};
    static const float b_hp[2] = {-2, 1};
    AVComplexFloat X[FREQ_SIZE];
    AVComplexFloat P[WINDOW_SIZE];
    float x[FRAME_SIZE];
    float Ex[NB_BANDS], Ep[NB_BANDS];
    LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
    float features[NB_FEATURES];
    float g[NB_BANDS];
    float gf[FREQ_SIZE];
    float vad_prob = 0;

    biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);

    if (compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x) == 0) {
        const float alpha = .6f;

        compute_rnn(s, &st->rnn, g, &vad_prob, features);
        pitch_filter(X, P, Ex, Ep, Exp, g);

        /* A band gain may not drop below alpha times its previous value. */
        for (int band = 0; band < NB_BANDS; band++) {
            g[band] = FFMAX(g[band], alpha * st->lastg[band]);
            st->lastg[band] = g[band];
        }

        interp_band_gain(gf, g);

        for (int bin = 0; bin < FREQ_SIZE; bin++) {
            X[bin].re *= gf[bin];
            X[bin].im *= gf[bin];
        }
    }

    frame_synthesis(s, st, out, X);

    return vad_prob;
}
1381
1382
typedef struct ThreadData {
1383
    AVFrame *in, *out;
1384
} ThreadData;
1385
1386
/**
 * Slice-threading worker: processes the contiguous range of channels
 * assigned to job jobnr out of nb_jobs.
 *
 * @param arg pointer to a ThreadData holding the input and output frames
 * @return always 0
 */
static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
{
    AudioRNNContext *s = ctx->priv;
    ThreadData *td = arg;
    AVFrame *src = td->in;
    AVFrame *dst = td->out;
    const int first = (dst->channels * jobnr) / nb_jobs;
    const int last  = (dst->channels * (jobnr+1)) / nb_jobs;
    int ch = first;

    while (ch < last) {
        rnnoise_channel(s, &s->st[ch],
                        (float *)dst->extended_data[ch],
                        (const float *)src->extended_data[ch]);
        ch++;
    }

    return 0;
}
1403
1404
/**
 * Denoise one input frame and send the result downstream.
 *
 * Takes ownership of in: it is freed on every path. The output frame always
 * holds FRAME_SIZE samples and inherits only the input pts.
 * NOTE(review): the workers read FRAME_SIZE samples per channel from in;
 * confirm the caller never passes a shorter frame (e.g. at end of stream).
 *
 * @return AVERROR(ENOMEM) if the output buffer cannot be allocated,
 *         otherwise the result of ff_filter_frame()
 */
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *ctx = inlink->dst;
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *out = ff_get_audio_buffer(outlink, FRAME_SIZE);
    ThreadData td;
    int nb_jobs;

    if (!out) {
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }
    out->pts = in->pts;

    td.in  = in;
    td.out = out;

    /* One job per channel at most, limited by the available threads. */
    nb_jobs = FFMIN(outlink->channels,
                    ff_filter_get_nb_threads(ctx));
    ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, nb_jobs);

    av_frame_free(&in);
    return ff_filter_frame(outlink, out);
}
1425
1426
/**
 * Scheduling callback: requests FRAME_SIZE-sample frames from the input
 * link and filters each one as it becomes available.
 *
 * @return a negative error code, the result of filter_frame(), or
 *         FFERROR_NOT_READY when nothing could be done
 */
static int activate(AVFilterContext *ctx)
{
    AVFilterLink *inlink = ctx->inputs[0];
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *in = NULL;
    int ret;

    FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);

    ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
    if (ret != 0)
        return ret < 0 ? ret : filter_frame(inlink, in);

    /* No frame available: propagate status and wanted-ness instead. */
    FF_FILTER_FORWARD_STATUS(inlink, outlink);
    FF_FILTER_FORWARD_WANTED(outlink, inlink);

    return FFERROR_NOT_READY;
}
1447
1448
/**
 * Filter initialisation: allocates the float DSP context, loads the model
 * named by the "model" option and precomputes the analysis window and the
 * DCT table.
 *
 * Every failure path logs the reason before returning.
 *
 * @return 0 on success, AVERROR(ENOMEM) if the DSP context cannot be
 *         allocated, AVERROR(EINVAL) if no model was given or it cannot be
 *         opened or parsed
 */
static av_cold int init(AVFilterContext *ctx)
{
    AudioRNNContext *s = ctx->priv;
    FILE *f;

    s->fdsp = avpriv_float_dsp_alloc(0);
    if (!s->fdsp)
        return AVERROR(ENOMEM);

    if (!s->model_name) {
        av_log(ctx, AV_LOG_ERROR, "No model file specified.\n");
        return AVERROR(EINVAL);
    }

    f = av_fopen_utf8(s->model_name, "r");
    if (!f) {
        av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    s->model = rnnoise_model_from_file(f);
    fclose(f);
    if (!s->model) {
        av_log(ctx, AV_LOG_ERROR, "Failed to load model from file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    /* Symmetric window: compute the first half and mirror it. */
    for (int i = 0; i < FRAME_SIZE; i++) {
        const double inner = sin(.5*M_PI*(i+.5)/FRAME_SIZE);

        s->window[i] = sin(.5*M_PI*inner * inner);
        s->window[WINDOW_SIZE - 1 - i] = s->window[i];
    }

    /* DCT basis; row 0 is additionally scaled by sqrt(0.5). */
    for (int i = 0; i < NB_BANDS; i++) {
        for (int j = 0; j < NB_BANDS; j++) {
            s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
            if (j == 0)
                s->dct_table[j][i] *= sqrtf(.5);
        }
    }

    return 0;
}
1483
1484
/**
 * Release everything owned by the filter context: the DSP context, the
 * model, and for every channel the GRU state buffers and both transform
 * contexts, followed by the channel state array itself.
 */
static av_cold void uninit(AVFilterContext *ctx)
{
    AudioRNNContext *s = ctx->priv;

    av_freep(&s->fdsp);
    rnnoise_model_free(s->model);
    s->model = NULL;

    /* The per-channel array may be missing; then there is nothing to walk. */
    for (int ch = 0; s->st && ch < s->channels; ch++) {
        av_freep(&s->st[ch].rnn.vad_gru_state);
        av_freep(&s->st[ch].rnn.noise_gru_state);
        av_freep(&s->st[ch].rnn.denoise_gru_state);
        av_tx_uninit(&s->st[ch].tx);
        av_tx_uninit(&s->st[ch].txi);
    }

    av_freep(&s->st);
}
1503
1504
static const AVFilterPad inputs[] = {
1505
    {
1506
        .name         = "default",
1507
        .type         = AVMEDIA_TYPE_AUDIO,
1508
        .config_props = config_input,
1509
    },
1510
    { NULL }
1511
};
1512
1513
static const AVFilterPad outputs[] = {
1514
    {
1515
        .name          = "default",
1516
        .type          = AVMEDIA_TYPE_AUDIO,
1517
    },
1518
    { NULL }
1519
};
1520
1521
#define OFFSET(x) offsetof(AudioRNNContext, x)
1522
#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1523
1524
static const AVOption arnndn_options[] = {
1525
    { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1526
    { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1527
    { NULL }
1528
};
1529
1530
AVFILTER_DEFINE_CLASS(arnndn);
1531
1532
AVFilter ff_af_arnndn = {
1533
    .name          = "arnndn",
1534
    .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1535
    .query_formats = query_formats,
1536
    .priv_size     = sizeof(AudioRNNContext),
1537
    .priv_class    = &arnndn_class,
1538
    .activate      = activate,
1539
    .init          = init,
1540
    .uninit        = uninit,
1541
    .inputs        = inputs,
1542
    .outputs       = outputs,
1543
    .flags         = AVFILTER_FLAG_SLICE_THREADS,
1544
};