FFmpeg coverage


Directory: ../../../ffmpeg/
File: src/libavfilter/af_arnndn.c
Date: 2024-04-19 07:31:02
Exec Total Coverage
Lines: 0 762 0.0%
Functions: 0 41 0.0%
Branches: 0 700 0.0%

Line Branch Exec Source
1 /*
2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 *
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
16 *
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include "libavutil/avassert.h"
35 #include "libavutil/file_open.h"
36 #include "libavutil/float_dsp.h"
37 #include "libavutil/mem.h"
38 #include "libavutil/mem_internal.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
/* Frame geometry. Input is 48 kHz only (see query_formats()), so one hop of
 * FRAME_SIZE = 480 samples is 10 ms; the analysis window spans two hops
 * (50% overlap) and a real FFT of it has FREQ_SIZE non-redundant bins. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE       (120 << FRAME_SIZE_SHIFT)
#define WINDOW_SIZE      (2 * FRAME_SIZE)
#define FREQ_SIZE        (FRAME_SIZE + 1)

/* Pitch analysis limits and history length, in samples. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE   (PITCH_MAX_PERIOD + PITCH_FRAME_SIZE)

#define SQUARE(x) ((x) * (x))

/* Number of spectral bands (one per eband5ms edge). */
#define NB_BANDS 22

/* Depth of the cepstral history and number of delta-cepstrum coefficients. */
#define CEPS_MEM      8
#define NB_DELTA_CEPS 6

/* Network input: band cepstrum, first and second cepstral deltas,
 * pitch-correlation cepstrum, pitch period and spectral variability. */
#define NB_FEATURES (NB_BANDS + 3 * NB_DELTA_CEPS + 2)

/* NOTE(review): presumably applied to the integer weights read from the
 * model file by the network evaluation code; confirm. */
#define WEIGHTS_SCALE (1.f / 256)

/* Maximum width of any network layer. */
#define MAX_NEURONS 128

/* In-memory activation codes stored in DenseLayer/GRULayer.activation. */
#define ACTIVATION_TANH    0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU    2

/* Float stand-in for the Q15 "1.0" of the fixed-point CELT pitch code. */
#define Q15ONE 1.0f
74
/* Fully connected layer as loaded by rnnoise_model_from_file(). */
struct DenseLayer {
    const float *bias;          /* nb_neurons values */
    const float *input_weights; /* nb_inputs * nb_neurons values */
    int nb_inputs, nb_neurons;
    int activation;             /* one of ACTIVATION_* */
};
typedef struct DenseLayer DenseLayer;
82
/* Gated recurrent layer as loaded by rnnoise_model_from_file(). The weight
 * matrices are stored by INPUT_ARRAY3 with the input dimension padded to a
 * multiple of 4; bias holds 3 * nb_neurons values. */
struct GRULayer {
    const float *bias;
    const float *input_weights;     /* from nb_inputs x nb_neurons x 3 values */
    const float *recurrent_weights; /* from nb_neurons x nb_neurons x 3 values */
    int nb_inputs, nb_neurons;
    int activation;                 /* one of ACTIVATION_* */
};
typedef struct GRULayer GRULayer;
91
92 typedef struct RNNModel {
93 int input_dense_size;
94 const DenseLayer *input_dense;
95
96 int vad_gru_size;
97 const GRULayer *vad_gru;
98
99 int noise_gru_size;
100 const GRULayer *noise_gru;
101
102 int denoise_gru_size;
103 const GRULayer *denoise_gru;
104
105 int denoise_output_size;
106 const DenseLayer *denoise_output;
107
108 int vad_output_size;
109 const DenseLayer *vad_output;
110 } RNNModel;
111
112 typedef struct RNNState {
113 float *vad_gru_state;
114 float *noise_gru_state;
115 float *denoise_gru_state;
116 RNNModel *model;
117 } RNNState;
118
119 typedef struct DenoiseState {
120 float analysis_mem[FRAME_SIZE];
121 float cepstral_mem[CEPS_MEM][NB_BANDS];
122 int memid;
123 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124 float pitch_buf[PITCH_BUF_SIZE];
125 float pitch_enh_buf[PITCH_BUF_SIZE];
126 float last_gain;
127 int last_period;
128 float mem_hp_x[2];
129 float lastg[NB_BANDS];
130 float history[FRAME_SIZE];
131 RNNState rnn[2];
132 AVTXContext *tx, *txi;
133 av_tx_fn tx_fn, txi_fn;
134 } DenoiseState;
135
136 typedef struct AudioRNNContext {
137 const AVClass *class;
138
139 char *model_name;
140 float mix;
141
142 int channels;
143 DenoiseState *st;
144
145 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
146 DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
147
148 RNNModel *model[2];
149
150 AVFloatDSPContext *fdsp;
151 } AudioRNNContext;
152
/* Activation codes as written in a model file; INPUT_ACTIVATION maps them
 * to ACTIVATION_* (unknown codes become tanh). */
#define F_ACTIVATION_TANH    0
#define F_ACTIVATION_SIGMOID 1
#define F_ACTIVATION_RELU    2
156
157 static void rnnoise_model_free(RNNModel *model)
158 {
159 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
160 #define FREE_DENSE(name) do { \
161 if (model->name) { \
162 av_free((void *) model->name->input_weights); \
163 av_free((void *) model->name->bias); \
164 av_free((void *) model->name); \
165 } \
166 } while (0)
167 #define FREE_GRU(name) do { \
168 if (model->name) { \
169 av_free((void *) model->name->input_weights); \
170 av_free((void *) model->name->recurrent_weights); \
171 av_free((void *) model->name->bias); \
172 av_free((void *) model->name); \
173 } \
174 } while (0)
175
176 if (!model)
177 return;
178 FREE_DENSE(input_dense);
179 FREE_GRU(vad_gru);
180 FREE_GRU(noise_gru);
181 FREE_GRU(denoise_gru);
182 FREE_DENSE(denoise_output);
183 FREE_DENSE(vad_output);
184 av_free(model);
185 }
186
187 static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
188 {
189 RNNModel *ret = NULL;
190 DenseLayer *input_dense;
191 GRULayer *vad_gru;
192 GRULayer *noise_gru;
193 GRULayer *denoise_gru;
194 DenseLayer *denoise_output;
195 DenseLayer *vad_output;
196 int in;
197
198 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
199 return AVERROR_INVALIDDATA;
200
201 ret = av_calloc(1, sizeof(RNNModel));
202 if (!ret)
203 return AVERROR(ENOMEM);
204
205 #define ALLOC_LAYER(type, name) \
206 name = av_calloc(1, sizeof(type)); \
207 if (!name) { \
208 rnnoise_model_free(ret); \
209 return AVERROR(ENOMEM); \
210 } \
211 ret->name = name
212
213 ALLOC_LAYER(DenseLayer, input_dense);
214 ALLOC_LAYER(GRULayer, vad_gru);
215 ALLOC_LAYER(GRULayer, noise_gru);
216 ALLOC_LAYER(GRULayer, denoise_gru);
217 ALLOC_LAYER(DenseLayer, denoise_output);
218 ALLOC_LAYER(DenseLayer, vad_output);
219
220 #define INPUT_VAL(name) do { \
221 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
222 rnnoise_model_free(ret); \
223 return AVERROR(EINVAL); \
224 } \
225 name = in; \
226 } while (0)
227
228 #define INPUT_ACTIVATION(name) do { \
229 int activation; \
230 INPUT_VAL(activation); \
231 switch (activation) { \
232 case F_ACTIVATION_SIGMOID: \
233 name = ACTIVATION_SIGMOID; \
234 break; \
235 case F_ACTIVATION_RELU: \
236 name = ACTIVATION_RELU; \
237 break; \
238 default: \
239 name = ACTIVATION_TANH; \
240 } \
241 } while (0)
242
243 #define INPUT_ARRAY(name, len) do { \
244 float *values = av_calloc((len), sizeof(float)); \
245 if (!values) { \
246 rnnoise_model_free(ret); \
247 return AVERROR(ENOMEM); \
248 } \
249 name = values; \
250 for (int i = 0; i < (len); i++) { \
251 if (fscanf(f, "%d", &in) != 1) { \
252 rnnoise_model_free(ret); \
253 return AVERROR(EINVAL); \
254 } \
255 values[i] = in; \
256 } \
257 } while (0)
258
259 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
260 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
261 if (!values) { \
262 rnnoise_model_free(ret); \
263 return AVERROR(ENOMEM); \
264 } \
265 name = values; \
266 for (int k = 0; k < (len0); k++) { \
267 for (int i = 0; i < (len2); i++) { \
268 for (int j = 0; j < (len1); j++) { \
269 if (fscanf(f, "%d", &in) != 1) { \
270 rnnoise_model_free(ret); \
271 return AVERROR(EINVAL); \
272 } \
273 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
274 } \
275 } \
276 } \
277 } while (0)
278
279 #define NEW_LINE() do { \
280 int c; \
281 while ((c = fgetc(f)) != EOF) { \
282 if (c == '\n') \
283 break; \
284 } \
285 } while (0)
286
287 #define INPUT_DENSE(name) do { \
288 INPUT_VAL(name->nb_inputs); \
289 INPUT_VAL(name->nb_neurons); \
290 ret->name ## _size = name->nb_neurons; \
291 INPUT_ACTIVATION(name->activation); \
292 NEW_LINE(); \
293 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
294 NEW_LINE(); \
295 INPUT_ARRAY(name->bias, name->nb_neurons); \
296 NEW_LINE(); \
297 } while (0)
298
299 #define INPUT_GRU(name) do { \
300 INPUT_VAL(name->nb_inputs); \
301 INPUT_VAL(name->nb_neurons); \
302 ret->name ## _size = name->nb_neurons; \
303 INPUT_ACTIVATION(name->activation); \
304 NEW_LINE(); \
305 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
306 NEW_LINE(); \
307 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
308 NEW_LINE(); \
309 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
310 NEW_LINE(); \
311 } while (0)
312
313 INPUT_DENSE(input_dense);
314 INPUT_GRU(vad_gru);
315 INPUT_GRU(noise_gru);
316 INPUT_GRU(denoise_gru);
317 INPUT_DENSE(denoise_output);
318 INPUT_DENSE(vad_output);
319
320 if (vad_output->nb_neurons != 1) {
321 rnnoise_model_free(ret);
322 return AVERROR(EINVAL);
323 }
324
325 *rnn = ret;
326
327 return 0;
328 }
329
330 static int query_formats(AVFilterContext *ctx)
331 {
332 static const enum AVSampleFormat sample_fmts[] = {
333 AV_SAMPLE_FMT_FLTP,
334 AV_SAMPLE_FMT_NONE
335 };
336 int ret, sample_rates[] = { 48000, -1 };
337
338 ret = ff_set_common_formats_from_list(ctx, sample_fmts);
339 if (ret < 0)
340 return ret;
341
342 ret = ff_set_common_all_channel_counts(ctx);
343 if (ret < 0)
344 return ret;
345
346 return ff_set_common_samplerates_from_list(ctx, sample_rates);
347 }
348
349 static int config_input(AVFilterLink *inlink)
350 {
351 AVFilterContext *ctx = inlink->dst;
352 AudioRNNContext *s = ctx->priv;
353 int ret = 0;
354
355 s->channels = inlink->ch_layout.nb_channels;
356
357 if (!s->st)
358 s->st = av_calloc(s->channels, sizeof(DenoiseState));
359 if (!s->st)
360 return AVERROR(ENOMEM);
361
362 for (int i = 0; i < s->channels; i++) {
363 DenoiseState *st = &s->st[i];
364
365 st->rnn[0].model = s->model[0];
366 st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
367 st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
368 st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
369 if (!st->rnn[0].vad_gru_state ||
370 !st->rnn[0].noise_gru_state ||
371 !st->rnn[0].denoise_gru_state)
372 return AVERROR(ENOMEM);
373 }
374
375 for (int i = 0; i < s->channels; i++) {
376 DenoiseState *st = &s->st[i];
377 float scale = 1.f;
378
379 if (!st->tx)
380 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, &scale, 0);
381 if (ret < 0)
382 return ret;
383
384 if (!st->txi)
385 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, &scale, 0);
386 if (ret < 0)
387 return ret;
388 }
389
390 return ret;
391 }
392
/**
 * Second-order IIR section (transposed direct form II):
 *   y[n] = x[n] + b[0]*x[n-1] + b[1]*x[n-2] - a[0]*y[n-1] - a[1]*y[n-2]
 * The leading numerator and denominator coefficients are implicitly 1.
 * mem[2] carries the filter state across calls. y may alias x.
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    float s0 = mem[0];
    float s1 = mem[1];

    for (int n = 0; n < N; n++) {
        const float in  = x[n];
        const float out = in + s0;

        s0 = s1 + (b[0] * in - a[0] * out);
        s1 = (b[1] * in - a[1] * out);
        y[n] = out;
    }

    mem[0] = s0;
    mem[1] = s1;
}
406
/* Element-count wrappers around memmove/memset/memcpy. The "0 * (d - s)"
 * term adds nothing to the byte count; subtracting the pointers makes the
 * compiler reject dst/src of incompatible element types. */
#define RNN_MOVE(d, s, cnt)  (memmove((d), (s), (cnt) * sizeof(*(d)) + 0 * ((d) - (s))))
#define RNN_CLEAR(d, cnt)    (memset((d), 0, (cnt) * sizeof(*(d))))
#define RNN_COPY(d, s, cnt)  (memcpy((d), (s), (cnt) * sizeof(*(d)) + 0 * ((d) - (s))))
410
411 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
412 {
413 AVComplexFloat x[WINDOW_SIZE];
414 AVComplexFloat y[WINDOW_SIZE];
415
416 for (int i = 0; i < WINDOW_SIZE; i++) {
417 x[i].re = in[i];
418 x[i].im = 0;
419 }
420
421 st->tx_fn(st->tx, y, x, sizeof(AVComplexFloat));
422
423 RNN_COPY(out, y, FREQ_SIZE);
424 }
425
426 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
427 {
428 AVComplexFloat x[WINDOW_SIZE];
429 AVComplexFloat y[WINDOW_SIZE];
430
431 RNN_COPY(x, in, FREQ_SIZE);
432
433 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
434 x[i].re = x[WINDOW_SIZE - i].re;
435 x[i].im = -x[WINDOW_SIZE - i].im;
436 }
437
438 st->txi_fn(st->txi, y, x, sizeof(AVComplexFloat));
439
440 for (int i = 0; i < WINDOW_SIZE; i++)
441 out[i] = y[i].re / WINDOW_SIZE;
442 }
443
444 static const uint8_t eband5ms[] = {
445 /*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
446 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
447 };
448
449 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
450 {
451 float sum[NB_BANDS] = {0};
452
453 for (int i = 0; i < NB_BANDS - 1; i++) {
454 int band_size;
455
456 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
457 for (int j = 0; j < band_size; j++) {
458 float tmp, frac = (float)j / band_size;
459
460 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
461 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
462 sum[i] += (1.f - frac) * tmp;
463 sum[i + 1] += frac * tmp;
464 }
465 }
466
467 sum[0] *= 2;
468 sum[NB_BANDS - 1] *= 2;
469
470 for (int i = 0; i < NB_BANDS; i++)
471 bandE[i] = sum[i];
472 }
473
474 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
475 {
476 float sum[NB_BANDS] = { 0 };
477
478 for (int i = 0; i < NB_BANDS - 1; i++) {
479 int band_size;
480
481 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
482 for (int j = 0; j < band_size; j++) {
483 float tmp, frac = (float)j / band_size;
484
485 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
486 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
487 sum[i] += (1 - frac) * tmp;
488 sum[i + 1] += frac * tmp;
489 }
490 }
491
492 sum[0] *= 2;
493 sum[NB_BANDS-1] *= 2;
494
495 for (int i = 0; i < NB_BANDS; i++)
496 bandE[i] = sum[i];
497 }
498
499 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
500 {
501 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
502
503 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
504 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
505 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
506 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
507 forward_transform(st, X, x);
508 compute_band_energy(Ex, X);
509 }
510
511 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
512 {
513 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
514 const float *src = st->history;
515 const float mix = s->mix;
516 const float imix = 1.f - FFMAX(mix, 0.f);
517
518 inverse_transform(st, x, y);
519 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
520 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
521 RNN_COPY(out, x, FRAME_SIZE);
522 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
523
524 for (int n = 0; n < FRAME_SIZE; n++)
525 out[n] = out[n] * mix + src[n] * imix;
526 }
527
/**
 * Accumulate four correlation lags at once:
 *   sum[k] += sum over j < len of x[j] * y[j + k],  for k = 0..3
 *
 * A rolling set of four y values (ya..yd) rotates roles every x sample, so
 * each y element is loaded only once. y must provide len + 3 values (its
 * first three are read even when len is 0).
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    float ya = y[0], yb = y[1], yc = y[2], yd = 0;
    int j = 0;

    for (; j < len - 3; j += 4) {
        float xv;

        yd = y[j + 3];
        xv = x[j];
        sum[0] += xv * ya;
        sum[1] += xv * yb;
        sum[2] += xv * yc;
        sum[3] += xv * yd;

        ya = y[j + 4];
        xv = x[j + 1];
        sum[0] += xv * yb;
        sum[1] += xv * yc;
        sum[2] += xv * yd;
        sum[3] += xv * ya;

        yb = y[j + 5];
        xv = x[j + 2];
        sum[0] += xv * yc;
        sum[1] += xv * yd;
        sum[2] += xv * ya;
        sum[3] += xv * yb;

        yc = y[j + 6];
        xv = x[j + 3];
        sum[0] += xv * yd;
        sum[1] += xv * ya;
        sum[2] += xv * yb;
        sum[3] += xv * yc;
    }

    /* Up to three leftover x samples, continuing the same rotation. */
    if (j < len) {
        const float xv = x[j];

        yd = y[j + 3];
        sum[0] += xv * ya;
        sum[1] += xv * yb;
        sum[2] += xv * yc;
        sum[3] += xv * yd;
    }

    if (j + 1 < len) {
        const float xv = x[j + 1];

        ya = y[j + 4];
        sum[0] += xv * yb;
        sum[1] += xv * yc;
        sum[2] += xv * yd;
        sum[3] += xv * ya;
    }

    if (j + 2 < len) {
        const float xv = x[j + 2];

        yb = y[j + 5];
        sum[0] += xv * yc;
        sum[1] += xv * yd;
        sum[2] += xv * ya;
        sum[3] += xv * yb;
    }
}
596
/* Dot product of the first N elements of x and y, accumulated in index order. */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;
    int k = 0;

    while (k < N) {
        acc += x[k] * y[k];
        k++;
    }

    return acc;
}
607
/**
 * Cross-correlation for lags 0 .. max_pitch - 1:
 *   xcorr[i] = sum over j < len of x[j] * y[i + j]
 * y must cover len + max_pitch - 1 samples.
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    /* Four lags per kernel call while at least four remain... */
    while (lag + 3 < max_pitch) {
        float acc[4] = { 0.f, 0.f, 0.f, 0.f };

        xcorr_kernel(x, y + lag, acc, len);
        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = acc[k];
        lag += 4;
    }

    /* ...then the leftover lags one at a time. */
    for (; lag < max_pitch; lag++)
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
}
628
629 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
630 float *ac, /* out: [0...lag-1] ac values */
631 const float *window,
632 int overlap,
633 int lag,
634 int n)
635 {
636 int fastN = n - lag;
637 int shift;
638 const float *xptr;
639 float xx[PITCH_BUF_SIZE>>1];
640
641 if (overlap == 0) {
642 xptr = x;
643 } else {
644 for (int i = 0; i < n; i++)
645 xx[i] = x[i];
646 for (int i = 0; i < overlap; i++) {
647 xx[i] = x[i] * window[i];
648 xx[n-i-1] = x[n-i-1] * window[i];
649 }
650 xptr = xx;
651 }
652
653 shift = 0;
654 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
655
656 for (int k = 0; k <= lag; k++) {
657 float d = 0.f;
658
659 for (int i = k + fastN; i < n; i++)
660 d += xptr[i] * xptr[i-k];
661 ac[k] += d;
662 }
663
664 return shift;
665 }
666
/**
 * Levinson-Durbin recursion: derive p LPC coefficients from the
 * autocorrelation ac[0..p]. Stops early once the prediction error falls
 * below 0.001 * ac[0] (30 dB). Leaves lpc all zero when ac[0] is 0.
 */
static void celt_lpc(float *lpc,       /* out: [0...p-1] LPC coefficients */
                     const float *ac,  /* in:  [0...p] autocorrelation values */
                     int p)
{
    float err = ac[0];

    RNN_CLEAR(lpc, p);
    if (ac[0] == 0)
        return;

    for (int i = 0; i < p; i++) {
        float rr = 0, r;

        /* Reflection coefficient for this order. */
        for (int j = 0; j < i; j++)
            rr += (lpc[j] * ac[i - j]);
        rr += ac[i + 1];
        r = -rr / err;

        /* Update the coefficients symmetrically and shrink the error. */
        lpc[i] = r;
        for (int j = 0; j < (i + 1) >> 1; j++) {
            const float lo = lpc[j];
            const float hi = lpc[i - 1 - j];

            lpc[j]         = lo + (r * hi);
            lpc[i - 1 - j] = hi + (r * lo);
        }

        err = err - (r * r * err);
        if (err < .001f * ac[0])
            break;
    }
}
699
/**
 * 5-tap FIR with implicit unity leading tap:
 *   y[n] = x[n] + num[0]*x[n-1] + ... + num[4]*x[n-5]
 * mem[0..4] holds the previous five inputs (most recent first) and is
 * updated on return. y may alias x.
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    const float c0 = num[0], c1 = num[1], c2 = num[2], c3 = num[3], c4 = num[4];
    float m0 = mem[0], m1 = mem[1], m2 = mem[2], m3 = mem[3], m4 = mem[4];

    for (int n = 0; n < N; n++) {
        const float in = x[n];
        float acc = in;

        acc += (c0 * m0);
        acc += (c1 * m1);
        acc += (c2 * m2);
        acc += (c3 * m3);
        acc += (c4 * m4);

        /* Shift the input history by one sample. */
        m4 = m3;
        m3 = m2;
        m2 = m1;
        m1 = m0;
        m0 = in;

        y[n] = acc;
    }

    mem[0] = m0;
    mem[1] = m1;
    mem[2] = m2;
    mem[3] = m3;
    mem[4] = m4;
}
742
743 static void pitch_downsample(float *x[], float *x_lp,
744 int len, int C)
745 {
746 float ac[5];
747 float tmp=Q15ONE;
748 float lpc[4], mem[5]={0,0,0,0,0};
749 float lpc2[5];
750 float c1 = .8f;
751
752 for (int i = 1; i < len >> 1; i++)
753 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
754 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
755 if (C==2) {
756 for (int i = 1; i < len >> 1; i++)
757 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
758 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
759 }
760
761 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
762
763 /* Noise floor -40 dB */
764 ac[0] *= 1.0001f;
765 /* Lag windowing */
766 for (int i = 1; i <= 4; i++) {
767 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
768 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
769 }
770
771 celt_lpc(lpc, ac, 4);
772 for (int i = 0; i < 4; i++) {
773 tmp = .9f * tmp;
774 lpc[i] = (lpc[i] * tmp);
775 }
776 /* Add a zero */
777 lpc2[0] = lpc[0] + .8f;
778 lpc2[1] = lpc[1] + (c1 * lpc[0]);
779 lpc2[2] = lpc[2] + (c1 * lpc[1]);
780 lpc2[3] = lpc[3] + (c1 * lpc[2]);
781 lpc2[4] = (c1 * lpc[3]);
782 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
783 }
784
/* Two dot products sharing x: *xy1 = x . y01 and *xy2 = x . y02 over N elements. */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;
    int k = 0;

    while (k < N) {
        const float xv = x[k];

        acc1 += (xv * y01[k]);
        acc2 += (xv * y02[k]);
        k++;
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
798
799 static float compute_pitch_gain(float xy, float xx, float yy)
800 {
801 return xy / sqrtf(1.f + xx * yy);
802 }
803
804 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
805 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
806 int *T0_, int prev_period, float prev_gain)
807 {
808 int k, i, T, T0;
809 float g, g0;
810 float pg;
811 float xy,xx,yy,xy2;
812 float xcorr[3];
813 float best_xy, best_yy;
814 int offset;
815 int minperiod0;
816 float yy_lookup[PITCH_MAX_PERIOD+1];
817
818 minperiod0 = minperiod;
819 maxperiod /= 2;
820 minperiod /= 2;
821 *T0_ /= 2;
822 prev_period /= 2;
823 N /= 2;
824 x += maxperiod;
825 if (*T0_>=maxperiod)
826 *T0_=maxperiod-1;
827
828 T = T0 = *T0_;
829 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
830 yy_lookup[0] = xx;
831 yy=xx;
832 for (i = 1; i <= maxperiod; i++) {
833 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
834 yy_lookup[i] = FFMAX(0, yy);
835 }
836 yy = yy_lookup[T0];
837 best_xy = xy;
838 best_yy = yy;
839 g = g0 = compute_pitch_gain(xy, xx, yy);
840 /* Look for any pitch at T/k */
841 for (k = 2; k <= 15; k++) {
842 int T1, T1b;
843 float g1;
844 float cont=0;
845 float thresh;
846 T1 = (2*T0+k)/(2*k);
847 if (T1 < minperiod)
848 break;
849 /* Look for another strong correlation at T1b */
850 if (k==2)
851 {
852 if (T1+T0>maxperiod)
853 T1b = T0;
854 else
855 T1b = T0+T1;
856 } else
857 {
858 T1b = (2*second_check[k]*T0+k)/(2*k);
859 }
860 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
861 xy = .5f * (xy + xy2);
862 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
863 g1 = compute_pitch_gain(xy, xx, yy);
864 if (FFABS(T1-prev_period)<=1)
865 cont = prev_gain;
866 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
867 cont = prev_gain * .5f;
868 else
869 cont = 0;
870 thresh = FFMAX(.3f, (.7f * g0) - cont);
871 /* Bias against very high pitch (very short period) to avoid false-positives
872 due to short-term correlation */
873 if (T1<3*minperiod)
874 thresh = FFMAX(.4f, (.85f * g0) - cont);
875 else if (T1<2*minperiod)
876 thresh = FFMAX(.5f, (.9f * g0) - cont);
877 if (g1 > thresh)
878 {
879 best_xy = xy;
880 best_yy = yy;
881 T = T1;
882 g = g1;
883 }
884 }
885 best_xy = FFMAX(0, best_xy);
886 if (best_yy <= best_xy)
887 pg = Q15ONE;
888 else
889 pg = best_xy/(best_yy + 1);
890
891 for (k = 0; k < 3; k++)
892 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
893 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
894 offset = 1;
895 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
896 offset = -1;
897 else
898 offset = 0;
899 if (pg > g)
900 pg = g;
901 *T0_ = 2*T+offset;
902
903 if (*T0_<minperiod0)
904 *T0_=minperiod0;
905 return pg;
906 }
907
/**
 * Pick the two lags i in [0, max_pitch) with the largest xcorr[i]^2 / Syy(i)
 * among positive correlations. Syy(i) is a running energy of
 * y[i .. i + len - 1], started at 1 and floored at 1. best_pitch[0] receives
 * the best lag and best_pitch[1] the runner-up (defaults 0 and 1).
 * y must hold len + max_pitch samples.
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float best_num[2] = { -1, -1 };
    float best_den[2] = { 0, 0 };
    float Syy = 1.f;

    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        Syy += y[j] * y[j];

    for (int i = 0; i < max_pitch; i++) {
        if (xcorr[i] > 0) {
            /* Scale down before squaring to avoid both underflow and overflow. */
            const float xc  = xcorr[i] * 1e-12f;
            const float num = xc * xc;

            /* num/Syy vs best_num/best_den, compared without division. */
            if ((num * best_den[1]) > (best_num[1] * Syy)) {
                if ((num * best_den[0]) > (best_num[0] * Syy)) {
                    /* New best: the previous best becomes the runner-up. */
                    best_num[1]   = best_num[0];
                    best_den[1]   = best_den[0];
                    best_pitch[1] = best_pitch[0];
                    best_num[0]   = num;
                    best_den[0]   = Syy;
                    best_pitch[0] = i;
                } else {
                    best_num[1]   = num;
                    best_den[1]   = Syy;
                    best_pitch[1] = i;
                }
            }
        }

        /* Slide the energy window forward by one sample. */
        Syy += y[i + len] * y[i + len] - y[i] * y[i];
        Syy = FFMAX(1, Syy);
    }
}
954
955 static void pitch_search(const float *x_lp, float *y,
956 int len, int max_pitch, int *pitch)
957 {
958 int lag;
959 int best_pitch[2]={0,0};
960 int offset;
961
962 float x_lp4[WINDOW_SIZE];
963 float y_lp4[WINDOW_SIZE];
964 float xcorr[WINDOW_SIZE];
965
966 lag = len+max_pitch;
967
968 /* Downsample by 2 again */
969 for (int j = 0; j < len >> 2; j++)
970 x_lp4[j] = x_lp[2*j];
971 for (int j = 0; j < lag >> 2; j++)
972 y_lp4[j] = y[2*j];
973
974 /* Coarse search with 4x decimation */
975
976 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
977
978 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
979
980 /* Finer search with 2x decimation */
981 for (int i = 0; i < max_pitch >> 1; i++) {
982 float sum;
983 xcorr[i] = 0;
984 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
985 continue;
986 sum = celt_inner_prod(x_lp, y+i, len>>1);
987 xcorr[i] = FFMAX(-1, sum);
988 }
989
990 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
991
992 /* Refine by pseudo-interpolation */
993 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
994 float a, b, c;
995
996 a = xcorr[best_pitch[0] - 1];
997 b = xcorr[best_pitch[0]];
998 c = xcorr[best_pitch[0] + 1];
999 if (c - a > .7f * (b - a))
1000 offset = 1;
1001 else if (a - c > .7f * (b-c))
1002 offset = -1;
1003 else
1004 offset = 0;
1005 } else {
1006 offset = 0;
1007 }
1008
1009 *pitch = 2 * best_pitch[0] - offset;
1010 }
1011
1012 static void dct(AudioRNNContext *s, float *out, const float *in)
1013 {
1014 for (int i = 0; i < NB_BANDS; i++) {
1015 float sum;
1016
1017 sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1018 out[i] = sum * sqrtf(2.f / 22);
1019 }
1020 }
1021
1022 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1023 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1024 {
1025 float E = 0;
1026 float *ceps_0, *ceps_1, *ceps_2;
1027 float spec_variability = 0;
1028 LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1029 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1030 float pitch_buf[PITCH_BUF_SIZE>>1];
1031 int pitch_index;
1032 float gain;
1033 float *(pre[1]);
1034 float tmp[NB_BANDS];
1035 float follow, logMax;
1036
1037 frame_analysis(s, st, X, Ex, in);
1038 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1039 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1040 pre[0] = &st->pitch_buf[0];
1041 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1042 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1043 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1044 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1045
1046 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1047 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1048 st->last_period = pitch_index;
1049 st->last_gain = gain;
1050
1051 for (int i = 0; i < WINDOW_SIZE; i++)
1052 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1053
1054 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1055 forward_transform(st, P, p);
1056 compute_band_energy(Ep, P);
1057 compute_band_corr(Exp, X, P);
1058
1059 for (int i = 0; i < NB_BANDS; i++)
1060 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1061
1062 dct(s, tmp, Exp);
1063
1064 for (int i = 0; i < NB_DELTA_CEPS; i++)
1065 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1066
1067 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1068 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1069 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1070 logMax = -2;
1071 follow = -2;
1072
1073 for (int i = 0; i < NB_BANDS; i++) {
1074 Ly[i] = log10f(1e-2f + Ex[i]);
1075 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1076 logMax = FFMAX(logMax, Ly[i]);
1077 follow = FFMAX(follow-1.5, Ly[i]);
1078 E += Ex[i];
1079 }
1080
1081 if (E < 0.04f) {
1082 /* If there's no audio, avoid messing up the state. */
1083 RNN_CLEAR(features, NB_FEATURES);
1084 return 1;
1085 }
1086
1087 dct(s, features, Ly);
1088 features[0] -= 12;
1089 features[1] -= 4;
1090 ceps_0 = st->cepstral_mem[st->memid];
1091 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1092 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1093
1094 for (int i = 0; i < NB_BANDS; i++)
1095 ceps_0[i] = features[i];
1096
1097 st->memid++;
1098 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1099 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1100 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1101 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1102 }
1103 /* Spectral variability features. */
1104 if (st->memid == CEPS_MEM)
1105 st->memid = 0;
1106
1107 for (int i = 0; i < CEPS_MEM; i++) {
1108 float mindist = 1e15f;
1109 for (int j = 0; j < CEPS_MEM; j++) {
1110 float dist = 0.f;
1111 for (int k = 0; k < NB_BANDS; k++) {
1112 float tmp;
1113
1114 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1115 dist += tmp*tmp;
1116 }
1117
1118 if (j != i)
1119 mindist = FFMIN(mindist, dist);
1120 }
1121
1122 spec_variability += mindist;
1123 }
1124
1125 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1126
1127 return 0;
1128 }
1129
1130 static void interp_band_gain(float *g, const float *bandE)
1131 {
1132 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1133
1134 for (int i = 0; i < NB_BANDS - 1; i++) {
1135 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1136
1137 for (int j = 0; j < band_size; j++) {
1138 float frac = (float)j / band_size;
1139
1140 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1141 }
1142 }
1143 }
1144
1145 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1146 const float *Exp, const float *g)
1147 {
1148 float newE[NB_BANDS];
1149 float r[NB_BANDS];
1150 float norm[NB_BANDS];
1151 float rf[FREQ_SIZE] = {0};
1152 float normf[FREQ_SIZE]={0};
1153
1154 for (int i = 0; i < NB_BANDS; i++) {
1155 if (Exp[i]>g[i]) r[i] = 1;
1156 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1157 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1158 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1159 }
1160 interp_band_gain(rf, r);
1161 for (int i = 0; i < FREQ_SIZE; i++) {
1162 X[i].re += rf[i]*P[i].re;
1163 X[i].im += rf[i]*P[i].im;
1164 }
1165 compute_band_energy(newE, X);
1166 for (int i = 0; i < NB_BANDS; i++) {
1167 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1168 }
1169 interp_band_gain(normf, norm);
1170 for (int i = 0; i < FREQ_SIZE; i++) {
1171 X[i].re *= normf[i];
1172 X[i].im *= normf[i];
1173 }
1174 }
1175
/*
 * tanh() sampled at x = 0.04 * i for i = 0..200, covering x in [0, 8].
 * For example, entry 25 is tanh(1) = 0.761594 and entry 1 is tanh(0.04).
 * Used by tansig_approx(), which rounds |x| to the nearest sample
 * (i = floor(0.5 + 25 * |x|)) and refines the value from there.
 */
static const float tansig_table[201] = {
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f,
};
1219
1220 static inline float tansig_approx(float x)
1221 {
1222 float y, dy;
1223 float sign=1;
1224 int i;
1225
1226 /* Tests are reversed to catch NaNs */
1227 if (!(x<8))
1228 return 1;
1229 if (!(x>-8))
1230 return -1;
1231 /* Another check in case of -ffast-math */
1232
1233 if (isnan(x))
1234 return 0;
1235
1236 if (x < 0) {
1237 x=-x;
1238 sign=-1;
1239 }
1240 i = (int)floor(.5f+25*x);
1241 x -= .04f*i;
1242 y = tansig_table[i];
1243 dy = 1-y*y;
1244 y = y + x*dy*(1 - y*x);
1245 return sign*y;
1246 }
1247
/**
 * Logistic sigmoid via the identity sigmoid(x) = (1 + tanh(x / 2)) / 2,
 * using tansig_approx() for the tanh.
 */
static inline float sigmoid_approx(float x)
{
    const float t = tansig_approx(.5f * x);

    return .5f * t + .5f;
}
1252
1253 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1254 {
1255 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1256
1257 for (int i = 0; i < N; i++) {
1258 /* Compute update gate. */
1259 float sum = layer->bias[i];
1260
1261 for (int j = 0; j < M; j++)
1262 sum += layer->input_weights[j * stride + i] * input[j];
1263
1264 output[i] = WEIGHTS_SCALE * sum;
1265 }
1266
1267 if (layer->activation == ACTIVATION_SIGMOID) {
1268 for (int i = 0; i < N; i++)
1269 output[i] = sigmoid_approx(output[i]);
1270 } else if (layer->activation == ACTIVATION_TANH) {
1271 for (int i = 0; i < N; i++)
1272 output[i] = tansig_approx(output[i]);
1273 } else if (layer->activation == ACTIVATION_RELU) {
1274 for (int i = 0; i < N; i++)
1275 output[i] = FFMAX(0, output[i]);
1276 } else {
1277 av_assert0(0);
1278 }
1279 }
1280
1281 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1282 {
1283 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1284 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1285 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1286 const int M = gru->nb_inputs;
1287 const int N = gru->nb_neurons;
1288 const int AN = FFALIGN(N, 4);
1289 const int AM = FFALIGN(M, 4);
1290 const int stride = 3 * AN, istride = 3 * AM;
1291
1292 for (int i = 0; i < N; i++) {
1293 /* Compute update gate. */
1294 float sum = gru->bias[i];
1295
1296 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1297 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1298 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1299 }
1300
1301 for (int i = 0; i < N; i++) {
1302 /* Compute reset gate. */
1303 float sum = gru->bias[N + i];
1304
1305 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1306 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1307 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1308 }
1309
1310 for (int i = 0; i < N; i++) {
1311 /* Compute output. */
1312 float sum = gru->bias[2 * N + i];
1313
1314 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1315 for (int j = 0; j < N; j++)
1316 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1317
1318 if (gru->activation == ACTIVATION_SIGMOID)
1319 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1320 else if (gru->activation == ACTIVATION_TANH)
1321 sum = tansig_approx(WEIGHTS_SCALE * sum);
1322 else if (gru->activation == ACTIVATION_RELU)
1323 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1324 else
1325 av_assert0(0);
1326 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1327 }
1328
1329 RNN_COPY(state, h, N);
1330 }
1331
1332 #define INPUT_SIZE 42
1333
1334 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1335 {
1336 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1337 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1338 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1339
1340 compute_dense(rnn->model->input_dense, dense_out, input);
1341 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1342 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1343
1344 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1345 memcpy(noise_input + rnn->model->input_dense_size,
1346 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1347 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1348 input, INPUT_SIZE * sizeof(float));
1349
1350 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1351
1352 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1353 memcpy(denoise_input + rnn->model->vad_gru_size,
1354 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1355 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1356 input, INPUT_SIZE * sizeof(float));
1357
1358 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1359 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1360 }
1361
1362 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1363 int disabled)
1364 {
1365 AVComplexFloat X[FREQ_SIZE];
1366 AVComplexFloat P[WINDOW_SIZE];
1367 float x[FRAME_SIZE];
1368 float Ex[NB_BANDS], Ep[NB_BANDS];
1369 LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1370 float features[NB_FEATURES];
1371 float g[NB_BANDS];
1372 float gf[FREQ_SIZE];
1373 float vad_prob = 0;
1374 float *history = st->history;
1375 static const float a_hp[2] = {-1.99599, 0.99600};
1376 static const float b_hp[2] = {-2, 1};
1377 int silence;
1378
1379 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1380 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1381
1382 if (!silence && !disabled) {
1383 compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
1384 pitch_filter(X, P, Ex, Ep, Exp, g);
1385 for (int i = 0; i < NB_BANDS; i++) {
1386 float alpha = .6f;
1387
1388 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1389 st->lastg[i] = g[i];
1390 }
1391
1392 interp_band_gain(gf, g);
1393
1394 for (int i = 0; i < FREQ_SIZE; i++) {
1395 X[i].re *= gf[i];
1396 X[i].im *= gf[i];
1397 }
1398 }
1399
1400 frame_synthesis(s, st, out, X);
1401 memcpy(history, in, FRAME_SIZE * sizeof(*history));
1402
1403 return vad_prob;
1404 }
1405
1406 typedef struct ThreadData {
1407 AVFrame *in, *out;
1408 } ThreadData;
1409
1410 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1411 {
1412 AudioRNNContext *s = ctx->priv;
1413 ThreadData *td = arg;
1414 AVFrame *in = td->in;
1415 AVFrame *out = td->out;
1416 const int start = (out->ch_layout.nb_channels * jobnr) / nb_jobs;
1417 const int end = (out->ch_layout.nb_channels * (jobnr+1)) / nb_jobs;
1418
1419 for (int ch = start; ch < end; ch++) {
1420 rnnoise_channel(s, &s->st[ch],
1421 (float *)out->extended_data[ch],
1422 (const float *)in->extended_data[ch],
1423 ctx->is_disabled);
1424 }
1425
1426 return 0;
1427 }
1428
1429 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1430 {
1431 AVFilterContext *ctx = inlink->dst;
1432 AVFilterLink *outlink = ctx->outputs[0];
1433 AVFrame *out = NULL;
1434 ThreadData td;
1435
1436 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1437 if (!out) {
1438 av_frame_free(&in);
1439 return AVERROR(ENOMEM);
1440 }
1441 av_frame_copy_props(out, in);
1442
1443 td.in = in; td.out = out;
1444 ff_filter_execute(ctx, rnnoise_channels, &td, NULL,
1445 FFMIN(outlink->ch_layout.nb_channels, ff_filter_get_nb_threads(ctx)));
1446
1447 av_frame_free(&in);
1448 return ff_filter_frame(outlink, out);
1449 }
1450
1451 static int activate(AVFilterContext *ctx)
1452 {
1453 AVFilterLink *inlink = ctx->inputs[0];
1454 AVFilterLink *outlink = ctx->outputs[0];
1455 AVFrame *in = NULL;
1456 int ret;
1457
1458 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1459
1460 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1461 if (ret < 0)
1462 return ret;
1463
1464 if (ret > 0)
1465 return filter_frame(inlink, in);
1466
1467 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1468 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1469
1470 return FFERROR_NOT_READY;
1471 }
1472
1473 static int open_model(AVFilterContext *ctx, RNNModel **model)
1474 {
1475 AudioRNNContext *s = ctx->priv;
1476 int ret;
1477 FILE *f;
1478
1479 if (!s->model_name)
1480 return AVERROR(EINVAL);
1481 f = avpriv_fopen_utf8(s->model_name, "r");
1482 if (!f) {
1483 av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
1484 return AVERROR(EINVAL);
1485 }
1486
1487 ret = rnnoise_model_from_file(f, model);
1488 fclose(f);
1489 if (!*model || ret < 0)
1490 return ret;
1491
1492 return 0;
1493 }
1494
1495 static av_cold int init(AVFilterContext *ctx)
1496 {
1497 AudioRNNContext *s = ctx->priv;
1498 int ret;
1499
1500 s->fdsp = avpriv_float_dsp_alloc(0);
1501 if (!s->fdsp)
1502 return AVERROR(ENOMEM);
1503
1504 ret = open_model(ctx, &s->model[0]);
1505 if (ret < 0)
1506 return ret;
1507
1508 for (int i = 0; i < FRAME_SIZE; i++) {
1509 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1510 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1511 }
1512
1513 for (int i = 0; i < NB_BANDS; i++) {
1514 for (int j = 0; j < NB_BANDS; j++) {
1515 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1516 if (j == 0)
1517 s->dct_table[j][i] *= sqrtf(.5);
1518 }
1519 }
1520
1521 return 0;
1522 }
1523
1524 static void free_model(AVFilterContext *ctx, int n)
1525 {
1526 AudioRNNContext *s = ctx->priv;
1527
1528 rnnoise_model_free(s->model[n]);
1529 s->model[n] = NULL;
1530
1531 for (int ch = 0; ch < s->channels && s->st; ch++) {
1532 av_freep(&s->st[ch].rnn[n].vad_gru_state);
1533 av_freep(&s->st[ch].rnn[n].noise_gru_state);
1534 av_freep(&s->st[ch].rnn[n].denoise_gru_state);
1535 }
1536 }
1537
1538 static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
1539 char *res, int res_len, int flags)
1540 {
1541 AudioRNNContext *s = ctx->priv;
1542 int ret;
1543
1544 ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
1545 if (ret < 0)
1546 return ret;
1547
1548 ret = open_model(ctx, &s->model[1]);
1549 if (ret < 0)
1550 return ret;
1551
1552 FFSWAP(RNNModel *, s->model[0], s->model[1]);
1553 for (int ch = 0; ch < s->channels; ch++)
1554 FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1555
1556 ret = config_input(ctx->inputs[0]);
1557 if (ret < 0) {
1558 for (int ch = 0; ch < s->channels; ch++)
1559 FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1560 FFSWAP(RNNModel *, s->model[0], s->model[1]);
1561 return ret;
1562 }
1563
1564 free_model(ctx, 1);
1565 return 0;
1566 }
1567
1568 static av_cold void uninit(AVFilterContext *ctx)
1569 {
1570 AudioRNNContext *s = ctx->priv;
1571
1572 av_freep(&s->fdsp);
1573 free_model(ctx, 0);
1574 for (int ch = 0; ch < s->channels && s->st; ch++) {
1575 av_tx_uninit(&s->st[ch].tx);
1576 av_tx_uninit(&s->st[ch].txi);
1577 }
1578 av_freep(&s->st);
1579 }
1580
1581 static const AVFilterPad inputs[] = {
1582 {
1583 .name = "default",
1584 .type = AVMEDIA_TYPE_AUDIO,
1585 .config_props = config_input,
1586 },
1587 };
1588
/* Offset of a field in the filter's private context, for the option table. */
#define OFFSET(x) offsetof(AudioRNNContext, x)
/* Audio filtering options that may also be changed at runtime
 * (runtime changes go through process_command()). */
#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM

static const AVOption arnndn_options[] = {
    /* Path of the model file opened by open_model(); "m" is a short alias. */
    { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    /* Output vs input mix in [-1, 1], default 1.0. It is applied outside this
     * section (presumably in the synthesis step - confirm). */
    { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
    { NULL }
};

/* Defines arnndn_class, which exposes arnndn_options. */
AVFILTER_DEFINE_CLASS(arnndn);
1600
1601 const AVFilter ff_af_arnndn = {
1602 .name = "arnndn",
1603 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1604 .priv_size = sizeof(AudioRNNContext),
1605 .priv_class = &arnndn_class,
1606 .activate = activate,
1607 .init = init,
1608 .uninit = uninit,
1609 FILTER_INPUTS(inputs),
1610 FILTER_OUTPUTS(ff_audio_default_filterpad),
1611 FILTER_QUERY_FUNC(query_formats),
1612 .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1613 AVFILTER_FLAG_SLICE_THREADS,
1614 .process_command = process_command,
1615 };
1616