FFmpeg coverage


Directory: ../../../ffmpeg/
File: src/libavfilter/af_arnndn.c
Date: 2022-11-26 13:19:19
Coverage (executed / total, percent)
Lines: 0 / 738 (0.0%)
Branches: 0 / 700 (0.0%)

Line Branch Exec Source
1 /*
2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 *
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
16 *
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include "libavutil/avassert.h"
35 #include "libavutil/file_open.h"
36 #include "libavutil/float_dsp.h"
37 #include "libavutil/mem_internal.h"
38 #include "libavutil/opt.h"
39 #include "libavutil/tx.h"
40 #include "avfilter.h"
41 #include "audio.h"
42 #include "filters.h"
43 #include "formats.h"
44
/* Frame geometry. With the 48000 Hz rate forced by query_formats(), a frame of
 * 120 << 2 = 480 samples is 10 ms; analysis windows span two frames. */
#define FRAME_SIZE_SHIFT   2
#define FRAME_SIZE         (120 << FRAME_SIZE_SHIFT)
#define WINDOW_SIZE        (2 * FRAME_SIZE)
#define FREQ_SIZE          (FRAME_SIZE + 1)

/* Pitch analysis: the history buffer holds the longest lag plus one frame. */
#define PITCH_MIN_PERIOD   60
#define PITCH_MAX_PERIOD   768
#define PITCH_FRAME_SIZE   960
#define PITCH_BUF_SIZE     (PITCH_MAX_PERIOD + PITCH_FRAME_SIZE)

#define SQUARE(x)          ((x) * (x))

/* Feature vector: NB_BANDS cepstral values, 3 * NB_DELTA_CEPS derived values
 * and 2 scalar features; CEPS_MEM frames of cepstral history are kept. */
#define NB_BANDS           22
#define CEPS_MEM           8
#define NB_DELTA_CEPS      6
#define NB_FEATURES        (NB_BANDS + 3 * NB_DELTA_CEPS + 2)

/* Model weights are read as integers; presumably scaled by this factor where
 * the layers are evaluated (outside this section) - confirm. */
#define WEIGHTS_SCALE      (1.f / 256)

/* Upper bound on layer sizes accepted from a model file. */
#define MAX_NEURONS        128

/* Internal activation codes stored in DenseLayer/GRULayer.activation. */
#define ACTIVATION_TANH    0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU    2

/* Float stand-in for the Q15 "1.0" constant of the fixed-point original. */
#define Q15ONE             1.0f
73
/*
 * Fully connected layer as loaded by rnnoise_model_from_file().
 * input_weights holds nb_inputs * nb_neurons values and bias holds
 * nb_neurons values; both arrays are owned by the layer.
 */
typedef struct DenseLayer {
    const float *bias;           /* nb_neurons entries */
    const float *input_weights;  /* nb_inputs * nb_neurons entries */
    int          nb_inputs;
    int          nb_neurons;
    int          activation;     /* one of ACTIVATION_* */
} DenseLayer;
81
/*
 * Gated recurrent layer as loaded by rnnoise_model_from_file().
 *
 * input_weights and recurrent_weights each hold three weight sets per neuron
 * (presumably the three GRU gates). The row for neuron j, set s starts at
 * (j * 3 + s) * FFALIGN(n_in, 4) and is zero-padded to that width. Here n_in
 * is nb_inputs for input_weights and nb_neurons for recurrent_weights.
 * bias holds 3 * nb_neurons values. All arrays are owned by the layer.
 */
typedef struct GRULayer {
    const float *bias;
    const float *input_weights;
    const float *recurrent_weights;
    int          nb_inputs;
    int          nb_neurons;
    int          activation;     /* one of ACTIVATION_* */
} GRULayer;
90
91 typedef struct RNNModel {
92 int input_dense_size;
93 const DenseLayer *input_dense;
94
95 int vad_gru_size;
96 const GRULayer *vad_gru;
97
98 int noise_gru_size;
99 const GRULayer *noise_gru;
100
101 int denoise_gru_size;
102 const GRULayer *denoise_gru;
103
104 int denoise_output_size;
105 const DenseLayer *denoise_output;
106
107 int vad_output_size;
108 const DenseLayer *vad_output;
109 } RNNModel;
110
111 typedef struct RNNState {
112 float *vad_gru_state;
113 float *noise_gru_state;
114 float *denoise_gru_state;
115 RNNModel *model;
116 } RNNState;
117
118 typedef struct DenoiseState {
119 float analysis_mem[FRAME_SIZE];
120 float cepstral_mem[CEPS_MEM][NB_BANDS];
121 int memid;
122 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
123 float pitch_buf[PITCH_BUF_SIZE];
124 float pitch_enh_buf[PITCH_BUF_SIZE];
125 float last_gain;
126 int last_period;
127 float mem_hp_x[2];
128 float lastg[NB_BANDS];
129 float history[FRAME_SIZE];
130 RNNState rnn[2];
131 AVTXContext *tx, *txi;
132 av_tx_fn tx_fn, txi_fn;
133 } DenoiseState;
134
135 typedef struct AudioRNNContext {
136 const AVClass *class;
137
138 char *model_name;
139 float mix;
140
141 int channels;
142 DenoiseState *st;
143
144 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
145 DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
146
147 RNNModel *model[2];
148
149 AVFloatDSPContext *fdsp;
150 } AudioRNNContext;
151
/* Activation codes as written in model files. rnnoise_model_from_file() maps
 * them to ACTIVATION_*; any unknown code falls back to tanh. */
#define F_ACTIVATION_TANH    0
#define F_ACTIVATION_SIGMOID 1
#define F_ACTIVATION_RELU    2
155
156 static void rnnoise_model_free(RNNModel *model)
157 {
158 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
159 #define FREE_DENSE(name) do { \
160 if (model->name) { \
161 av_free((void *) model->name->input_weights); \
162 av_free((void *) model->name->bias); \
163 av_free((void *) model->name); \
164 } \
165 } while (0)
166 #define FREE_GRU(name) do { \
167 if (model->name) { \
168 av_free((void *) model->name->input_weights); \
169 av_free((void *) model->name->recurrent_weights); \
170 av_free((void *) model->name->bias); \
171 av_free((void *) model->name); \
172 } \
173 } while (0)
174
175 if (!model)
176 return;
177 FREE_DENSE(input_dense);
178 FREE_GRU(vad_gru);
179 FREE_GRU(noise_gru);
180 FREE_GRU(denoise_gru);
181 FREE_DENSE(denoise_output);
182 FREE_DENSE(vad_output);
183 av_free(model);
184 }
185
186 static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
187 {
188 RNNModel *ret = NULL;
189 DenseLayer *input_dense;
190 GRULayer *vad_gru;
191 GRULayer *noise_gru;
192 GRULayer *denoise_gru;
193 DenseLayer *denoise_output;
194 DenseLayer *vad_output;
195 int in;
196
197 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
198 return AVERROR_INVALIDDATA;
199
200 ret = av_calloc(1, sizeof(RNNModel));
201 if (!ret)
202 return AVERROR(ENOMEM);
203
204 #define ALLOC_LAYER(type, name) \
205 name = av_calloc(1, sizeof(type)); \
206 if (!name) { \
207 rnnoise_model_free(ret); \
208 return AVERROR(ENOMEM); \
209 } \
210 ret->name = name
211
212 ALLOC_LAYER(DenseLayer, input_dense);
213 ALLOC_LAYER(GRULayer, vad_gru);
214 ALLOC_LAYER(GRULayer, noise_gru);
215 ALLOC_LAYER(GRULayer, denoise_gru);
216 ALLOC_LAYER(DenseLayer, denoise_output);
217 ALLOC_LAYER(DenseLayer, vad_output);
218
219 #define INPUT_VAL(name) do { \
220 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
221 rnnoise_model_free(ret); \
222 return AVERROR(EINVAL); \
223 } \
224 name = in; \
225 } while (0)
226
227 #define INPUT_ACTIVATION(name) do { \
228 int activation; \
229 INPUT_VAL(activation); \
230 switch (activation) { \
231 case F_ACTIVATION_SIGMOID: \
232 name = ACTIVATION_SIGMOID; \
233 break; \
234 case F_ACTIVATION_RELU: \
235 name = ACTIVATION_RELU; \
236 break; \
237 default: \
238 name = ACTIVATION_TANH; \
239 } \
240 } while (0)
241
242 #define INPUT_ARRAY(name, len) do { \
243 float *values = av_calloc((len), sizeof(float)); \
244 if (!values) { \
245 rnnoise_model_free(ret); \
246 return AVERROR(ENOMEM); \
247 } \
248 name = values; \
249 for (int i = 0; i < (len); i++) { \
250 if (fscanf(f, "%d", &in) != 1) { \
251 rnnoise_model_free(ret); \
252 return AVERROR(EINVAL); \
253 } \
254 values[i] = in; \
255 } \
256 } while (0)
257
258 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
259 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
260 if (!values) { \
261 rnnoise_model_free(ret); \
262 return AVERROR(ENOMEM); \
263 } \
264 name = values; \
265 for (int k = 0; k < (len0); k++) { \
266 for (int i = 0; i < (len2); i++) { \
267 for (int j = 0; j < (len1); j++) { \
268 if (fscanf(f, "%d", &in) != 1) { \
269 rnnoise_model_free(ret); \
270 return AVERROR(EINVAL); \
271 } \
272 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
273 } \
274 } \
275 } \
276 } while (0)
277
278 #define NEW_LINE() do { \
279 int c; \
280 while ((c = fgetc(f)) != EOF) { \
281 if (c == '\n') \
282 break; \
283 } \
284 } while (0)
285
286 #define INPUT_DENSE(name) do { \
287 INPUT_VAL(name->nb_inputs); \
288 INPUT_VAL(name->nb_neurons); \
289 ret->name ## _size = name->nb_neurons; \
290 INPUT_ACTIVATION(name->activation); \
291 NEW_LINE(); \
292 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
293 NEW_LINE(); \
294 INPUT_ARRAY(name->bias, name->nb_neurons); \
295 NEW_LINE(); \
296 } while (0)
297
298 #define INPUT_GRU(name) do { \
299 INPUT_VAL(name->nb_inputs); \
300 INPUT_VAL(name->nb_neurons); \
301 ret->name ## _size = name->nb_neurons; \
302 INPUT_ACTIVATION(name->activation); \
303 NEW_LINE(); \
304 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
305 NEW_LINE(); \
306 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
307 NEW_LINE(); \
308 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
309 NEW_LINE(); \
310 } while (0)
311
312 INPUT_DENSE(input_dense);
313 INPUT_GRU(vad_gru);
314 INPUT_GRU(noise_gru);
315 INPUT_GRU(denoise_gru);
316 INPUT_DENSE(denoise_output);
317 INPUT_DENSE(vad_output);
318
319 if (vad_output->nb_neurons != 1) {
320 rnnoise_model_free(ret);
321 return AVERROR(EINVAL);
322 }
323
324 *rnn = ret;
325
326 return 0;
327 }
328
329 static int query_formats(AVFilterContext *ctx)
330 {
331 static const enum AVSampleFormat sample_fmts[] = {
332 AV_SAMPLE_FMT_FLTP,
333 AV_SAMPLE_FMT_NONE
334 };
335 int ret, sample_rates[] = { 48000, -1 };
336
337 ret = ff_set_common_formats_from_list(ctx, sample_fmts);
338 if (ret < 0)
339 return ret;
340
341 ret = ff_set_common_all_channel_counts(ctx);
342 if (ret < 0)
343 return ret;
344
345 return ff_set_common_samplerates_from_list(ctx, sample_rates);
346 }
347
348 static int config_input(AVFilterLink *inlink)
349 {
350 AVFilterContext *ctx = inlink->dst;
351 AudioRNNContext *s = ctx->priv;
352 int ret = 0;
353
354 s->channels = inlink->ch_layout.nb_channels;
355
356 if (!s->st)
357 s->st = av_calloc(s->channels, sizeof(DenoiseState));
358 if (!s->st)
359 return AVERROR(ENOMEM);
360
361 for (int i = 0; i < s->channels; i++) {
362 DenoiseState *st = &s->st[i];
363
364 st->rnn[0].model = s->model[0];
365 st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
366 st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
367 st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
368 if (!st->rnn[0].vad_gru_state ||
369 !st->rnn[0].noise_gru_state ||
370 !st->rnn[0].denoise_gru_state)
371 return AVERROR(ENOMEM);
372 }
373
374 for (int i = 0; i < s->channels; i++) {
375 DenoiseState *st = &s->st[i];
376
377 if (!st->tx)
378 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
379 if (ret < 0)
380 return ret;
381
382 if (!st->txi)
383 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
384 if (ret < 0)
385 return ret;
386 }
387
388 return ret;
389 }
390
/*
 * Second-order IIR section in transposed direct form II with unit leading
 * coefficients:
 *   y[n] = x[n] + b[0] x[n-1] + b[1] x[n-2] - a[0] y[n-1] - a[1] y[n-2]
 *
 * mem[2] carries the filter state across calls and is updated in place.
 * Each input sample is read before the matching output sample is written.
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    for (int left = N; left > 0; left--) {
        const float in  = *x++;
        const float out = in + mem[0];

        mem[0] = mem[1] + (b[0]*in - a[0]*out);
        mem[1] = (b[1]*in - a[1]*out);
        *y++ = out;
    }
}
404
/*
 * Element-count wrappers around memmove/memset/memcpy: n counts elements of
 * *dst, not bytes. The dead term 0 * ((dst) - (src)) makes the compiler reject
 * dst/src pointers of incompatible types.
 */
#define RNN_MOVE(dst, src, n)  (memmove((dst), (src), (n) * sizeof(*(dst)) + 0 * ((dst) - (src))))
#define RNN_CLEAR(dst, n)      (memset((dst), 0, (n) * sizeof(*(dst))))
#define RNN_COPY(dst, src, n)  (memcpy((dst), (src), (n) * sizeof(*(dst)) + 0 * ((dst) - (src))))
408
409 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
410 {
411 AVComplexFloat x[WINDOW_SIZE];
412 AVComplexFloat y[WINDOW_SIZE];
413
414 for (int i = 0; i < WINDOW_SIZE; i++) {
415 x[i].re = in[i];
416 x[i].im = 0;
417 }
418
419 st->tx_fn(st->tx, y, x, sizeof(float));
420
421 RNN_COPY(out, y, FREQ_SIZE);
422 }
423
424 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
425 {
426 AVComplexFloat x[WINDOW_SIZE];
427 AVComplexFloat y[WINDOW_SIZE];
428
429 RNN_COPY(x, in, FREQ_SIZE);
430
431 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
432 x[i].re = x[WINDOW_SIZE - i].re;
433 x[i].im = -x[WINDOW_SIZE - i].im;
434 }
435
436 st->txi_fn(st->txi, y, x, sizeof(float));
437
438 for (int i = 0; i < WINDOW_SIZE; i++)
439 out[i] = y[i].re / WINDOW_SIZE;
440 }
441
/*
 * Band edges in units of 200 Hz. Shifting left by FRAME_SIZE_SHIFT turns them
 * into FFT bin indices (50 Hz per bin for a WINDOW_SIZE FFT at 48000 Hz).
 */
static const uint8_t eband5ms[] = {
 /* Hz: 0   200 400 600 800 1k  1.2k 1.4k 1.6k 2k  2.4k 2.8k 3.2k 4k  4.8k 5.6k 6.8k 8k  9.6k 12k 15.6k 20k */
        0,  1,  2,  3,  4,  5,  6,   7,   8,   10, 12,  14,  16,  20, 24,  28,  34,  40, 48,  60, 78,   100
};
446
447 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
448 {
449 float sum[NB_BANDS] = {0};
450
451 for (int i = 0; i < NB_BANDS - 1; i++) {
452 int band_size;
453
454 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
455 for (int j = 0; j < band_size; j++) {
456 float tmp, frac = (float)j / band_size;
457
458 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
459 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
460 sum[i] += (1.f - frac) * tmp;
461 sum[i + 1] += frac * tmp;
462 }
463 }
464
465 sum[0] *= 2;
466 sum[NB_BANDS - 1] *= 2;
467
468 for (int i = 0; i < NB_BANDS; i++)
469 bandE[i] = sum[i];
470 }
471
472 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
473 {
474 float sum[NB_BANDS] = { 0 };
475
476 for (int i = 0; i < NB_BANDS - 1; i++) {
477 int band_size;
478
479 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
480 for (int j = 0; j < band_size; j++) {
481 float tmp, frac = (float)j / band_size;
482
483 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
484 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
485 sum[i] += (1 - frac) * tmp;
486 sum[i + 1] += frac * tmp;
487 }
488 }
489
490 sum[0] *= 2;
491 sum[NB_BANDS-1] *= 2;
492
493 for (int i = 0; i < NB_BANDS; i++)
494 bandE[i] = sum[i];
495 }
496
497 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
498 {
499 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
500
501 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
502 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
503 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
504 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
505 forward_transform(st, X, x);
506 compute_band_energy(Ex, X);
507 }
508
509 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
510 {
511 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
512 const float *src = st->history;
513 const float mix = s->mix;
514 const float imix = 1.f - FFMAX(mix, 0.f);
515
516 inverse_transform(st, x, y);
517 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
518 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
519 RNN_COPY(out, x, FRAME_SIZE);
520 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
521
522 for (int n = 0; n < FRAME_SIZE; n++)
523 out[n] = out[n] * mix + src[n] * imix;
524 }
525
/*
 * Accumulate four lagged dot products into sum[]:
 *   sum[k] += x[0]*y[k] + x[1]*y[1+k] + ... + x[len-1]*y[len-1+k],  k = 0..3
 * Each sum[k] is accumulated in increasing order of j. Reads x[0..len-1] and
 * y[0..len+2].
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    for (int j = 0; j < len; j++) {
        const float xj = x[j];

        sum[0] += xj * y[j];
        sum[1] += xj * y[j + 1];
        sum[2] += xj * y[j + 2];
        sum[3] += xj * y[j + 3];
    }
}
594
/* Dot product of the first N elements of x and y (0 for N <= 0). */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;

    for (int left = N; left > 0; left--)
        acc += *x++ * *y++;

    return acc;
}
605
/*
 * Cross-correlation xcorr[lag] = sum_{j<len} x[j] * y[j + lag] for
 * lag = 0..max_pitch-1. Lags are processed four at a time with xcorr_kernel(),
 * and the remaining 0-3 lags one by one.
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    for (; lag + 3 < max_pitch; lag += 4) {
        float acc[4] = { 0, 0, 0, 0 };

        xcorr_kernel(x, y + lag, acc, len);
        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = acc[k];
    }

    while (lag < max_pitch) {
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
        lag++;
    }
}
626
627 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
628 float *ac, /* out: [0...lag-1] ac values */
629 const float *window,
630 int overlap,
631 int lag,
632 int n)
633 {
634 int fastN = n - lag;
635 int shift;
636 const float *xptr;
637 float xx[PITCH_BUF_SIZE>>1];
638
639 if (overlap == 0) {
640 xptr = x;
641 } else {
642 for (int i = 0; i < n; i++)
643 xx[i] = x[i];
644 for (int i = 0; i < overlap; i++) {
645 xx[i] = x[i] * window[i];
646 xx[n-i-1] = x[n-i-1] * window[i];
647 }
648 xptr = xx;
649 }
650
651 shift = 0;
652 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
653
654 for (int k = 0; k <= lag; k++) {
655 float d = 0.f;
656
657 for (int i = k + fastN; i < n; i++)
658 d += xptr[i] * xptr[i-k];
659 ac[k] += d;
660 }
661
662 return shift;
663 }
664
/*
 * Levinson-Durbin recursion: compute p LPC coefficients from p + 1
 * autocorrelation values. Leaves all-zero coefficients when ac[0] == 0 and
 * stops early once the prediction error falls below 0.001 * ac[0] (30 dB gain).
 *
 * @param lpc  output, p coefficients
 * @param ac   input, p + 1 autocorrelation values
 * @param p    prediction order
 */
static void celt_lpc(float *lpc,
                     const float *ac,
                     int p)
{
    float err = ac[0];

    for (int i = 0; i < p; i++)
        lpc[i] = 0.f;

    if (ac[0] == 0)
        return;

    for (int i = 0; i < p; i++) {
        float acc = 0;
        float refl;

        /* Reflection coefficient for this order. */
        for (int j = 0; j < i; j++)
            acc += (lpc[j] * ac[i - j]);
        acc += ac[i + 1];
        refl = -acc/err;

        /* Update the coefficients symmetrically from both ends. */
        lpc[i] = refl;
        for (int j = 0; j < (i + 1) >> 1; j++) {
            const float lo = lpc[j];
            const float hi = lpc[i - 1 - j];

            lpc[j]         = lo + (refl*hi);
            lpc[i - 1 - j] = hi + (refl*lo);
        }

        err -= (refl * refl *err);
        if (err < .001f * ac[0])
            break;
    }
}
697
/*
 * 5-tap FIR filter: y[i] = x[i] + sum_k num[k] * x[i - 1 - k].
 * mem[0..4] holds the five previous inputs (most recent first) and is updated
 * in place. x and y may be the same buffer: each x[i] is read before y[i] is
 * written.
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float taps[5];
    float hist[5];

    for (int k = 0; k < 5; k++) {
        taps[k] = num[k];
        hist[k] = mem[k];
    }

    for (int i = 0; i < N; i++) {
        const float in = x[i];
        float acc = in;

        for (int k = 0; k < 5; k++)
            acc += (taps[k]*hist[k]);

        /* Push the current input into the delay line. */
        for (int k = 4; k > 0; k--)
            hist[k] = hist[k - 1];
        hist[0] = in;

        y[i] = acc;
    }

    for (int k = 0; k < 5; k++)
        mem[k] = hist[k];
}
740
741 static void pitch_downsample(float *x[], float *x_lp,
742 int len, int C)
743 {
744 float ac[5];
745 float tmp=Q15ONE;
746 float lpc[4], mem[5]={0,0,0,0,0};
747 float lpc2[5];
748 float c1 = .8f;
749
750 for (int i = 1; i < len >> 1; i++)
751 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
752 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
753 if (C==2) {
754 for (int i = 1; i < len >> 1; i++)
755 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
756 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
757 }
758
759 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
760
761 /* Noise floor -40 dB */
762 ac[0] *= 1.0001f;
763 /* Lag windowing */
764 for (int i = 1; i <= 4; i++) {
765 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
766 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
767 }
768
769 celt_lpc(lpc, ac, 4);
770 for (int i = 0; i < 4; i++) {
771 tmp = .9f * tmp;
772 lpc[i] = (lpc[i] * tmp);
773 }
774 /* Add a zero */
775 lpc2[0] = lpc[0] + .8f;
776 lpc2[1] = lpc[1] + (c1 * lpc[0]);
777 lpc2[2] = lpc[2] + (c1 * lpc[1]);
778 lpc2[3] = lpc[3] + (c1 * lpc[2]);
779 lpc2[4] = (c1 * lpc[3]);
780 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
781 }
782
/* Two dot products sharing x: *xy1 = <x, y01>, *xy2 = <x, y02> over N elements. */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;

    for (int i = 0; i < N; i++) {
        const float xi = x[i];

        acc1 += (xi * y01[i]);
        acc2 += (xi * y02[i]);
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
796
797 static float compute_pitch_gain(float xy, float xx, float yy)
798 {
799 return xy / sqrtf(1.f + xx * yy);
800 }
801
/* Numerators m used by remove_doubling() for k >= 3: the secondary lag is
 * T1b = round(m * T0 / k) with m = second_check[k]. */
static const uint8_t second_check[16] = {
    0, 0, 3, 2, 3, 2, 5, 2,
    3, 2, 3, 2, 5, 2, 3, 2
};
/*
 * Refine a pitch period estimate by testing whether a sub-multiple T/k
 * (k = 2..15) correlates about as well. This guards against picking a
 * multiple of the true period.
 *
 * All analysis runs at half resolution: maxperiod, minperiod, N, *T0_ and
 * prev_period are halved on entry, and the chosen period is scaled back by 2
 * (with a +-1 sub-sample correction) on exit.
 *
 * @param x           signal; maxperiod/2 + N/2 samples are read
 * @param maxperiod   largest period considered (full resolution)
 * @param minperiod   smallest period considered (full resolution)
 * @param N           analysis length (full resolution)
 * @param T0_         in: initial period estimate; out: refined period,
 *                    never below minperiod
 * @param prev_period period chosen for the previous frame
 * @param prev_gain   gain returned for the previous frame
 * @return pitch gain of the chosen period: best_xy / (best_yy + 1), or Q15ONE
 *         when best_yy <= best_xy, and in both cases at most g
 */
static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
                             int *T0_, int prev_period, float prev_gain)
{
    int k, i, T, T0;
    float g, g0;
    float pg;
    float xy,xx,yy,xy2;
    float xcorr[3];
    float best_xy, best_yy;
    int offset;
    int minperiod0;
    float yy_lookup[PITCH_MAX_PERIOD+1];

    minperiod0 = minperiod;
    /* Switch to half resolution. */
    maxperiod /= 2;
    minperiod /= 2;
    *T0_ /= 2;
    prev_period /= 2;
    N /= 2;
    x += maxperiod;
    if (*T0_>=maxperiod)
        *T0_=maxperiod-1;

    T = T0 = *T0_;
    dual_inner_prod(x, x, x-T0, N, &xx, &xy);
    /* yy_lookup[i]: energy of x[-i .. N-1-i], updated incrementally and clamped at 0. */
    yy_lookup[0] = xx;
    yy=xx;
    for (i = 1; i <= maxperiod; i++) {
        yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
        yy_lookup[i] = FFMAX(0, yy);
    }
    yy = yy_lookup[T0];
    best_xy = xy;
    best_yy = yy;
    g = g0 = compute_pitch_gain(xy, xx, yy);
    /* Look for any pitch at T/k */
    for (k = 2; k <= 15; k++) {
        int T1, T1b;
        float g1;
        float cont=0;
        float thresh;
        /* T1 = T0 / k, rounded to nearest. */
        T1 = (2*T0+k)/(2*k);
        if (T1 < minperiod)
            break;
        /* Look for another strong correlation at T1b */
        if (k==2)
        {
            if (T1+T0>maxperiod)
                T1b = T0;
            else
                T1b = T0+T1;
        } else
        {
            T1b = (2*second_check[k]*T0+k)/(2*k);
        }
        /* Candidate score: average correlation/energy at T1 and T1b. */
        dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
        xy = .5f * (xy + xy2);
        yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
        g1 = compute_pitch_gain(xy, xx, yy);
        /* Continuity bonus when T1 is close to the previous frame's period. */
        if (FFABS(T1-prev_period)<=1)
            cont = prev_gain;
        else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
            cont = prev_gain * .5f;
        else
            cont = 0;
        thresh = FFMAX(.3f, (.7f * g0) - cont);
        /* Bias against very high pitch (very short period) to avoid false-positives
           due to short-term correlation */
        /* NOTE(review): for minperiod > 0 the `else if` below is unreachable,
         * because any T1 < 2*minperiod already satisfied T1 < 3*minperiod.
         * Reordering the tests would change which periods are accepted. */
        if (T1<3*minperiod)
            thresh = FFMAX(.4f, (.85f * g0) - cont);
        else if (T1<2*minperiod)
            thresh = FFMAX(.5f, (.9f * g0) - cont);
        if (g1 > thresh)
        {
            best_xy = xy;
            best_yy = yy;
            T = T1;
            g = g1;
        }
    }
    best_xy = FFMAX(0, best_xy);
    if (best_yy <= best_xy)
        pg = Q15ONE;
    else
        pg = best_xy/(best_yy + 1);

    /* Sub-sample refinement: nudge 2*T by +-1 toward the stronger neighbouring lag. */
    for (k = 0; k < 3; k++)
        xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
    if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
        offset = 1;
    else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
        offset = -1;
    else
        offset = 0;
    if (pg > g)
        pg = g;
    /* Back to full resolution. */
    *T0_ = 2*T+offset;

    if (*T0_<minperiod0)
        *T0_=minperiod0;
    return pg;
}
905
/*
 * Pick the two lags in [0, max_pitch) with the highest score
 * xcorr[lag]^2 / energy(lag), where energy(lag) = 1 + sum of y[lag..lag+len-1]^2
 * clamped at 1. Only lags with positive xcorr are considered.
 * best_pitch[0] receives the best lag and best_pitch[1] the runner-up; they
 * default to 0 and 1. Reads y[0 .. max_pitch + len - 1].
 *
 * Scores are compared as fractions without dividing (num_a * den_b >
 * num_b * den_a). Stored denominators are 0 for an empty slot, otherwise >= 1.
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float top_num[2] = { -1, -1 };
    float top_den[2] = { 0, 0 };
    float energy = 1.f;

    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        energy += y[j] * y[j];

    for (int lag = 0; lag < max_pitch; lag++) {
        if (xcorr[lag] > 0) {
            /* Scale before squaring to keep the square clear of underflow
             * and overflow (inf). */
            const float c = xcorr[lag] * 1e-12f;
            const float score = c * c;

            if ((score * top_den[1]) > (top_num[1] * energy)) {
                if ((score * top_den[0]) > (top_num[0] * energy)) {
                    /* New best: demote the previous best to runner-up. */
                    top_num[1]    = top_num[0];
                    top_den[1]    = top_den[0];
                    best_pitch[1] = best_pitch[0];
                    top_num[0]    = score;
                    top_den[0]    = energy;
                    best_pitch[0] = lag;
                } else {
                    top_num[1]    = score;
                    top_den[1]    = energy;
                    best_pitch[1] = lag;
                }
            }
        }

        /* Slide the energy window forward by one sample. */
        energy += y[lag+len]*y[lag+len] - y[lag] * y[lag];
        energy = FFMAX(1, energy);
    }
}
952
953 static void pitch_search(const float *x_lp, float *y,
954 int len, int max_pitch, int *pitch)
955 {
956 int lag;
957 int best_pitch[2]={0,0};
958 int offset;
959
960 float x_lp4[WINDOW_SIZE];
961 float y_lp4[WINDOW_SIZE];
962 float xcorr[WINDOW_SIZE];
963
964 lag = len+max_pitch;
965
966 /* Downsample by 2 again */
967 for (int j = 0; j < len >> 2; j++)
968 x_lp4[j] = x_lp[2*j];
969 for (int j = 0; j < lag >> 2; j++)
970 y_lp4[j] = y[2*j];
971
972 /* Coarse search with 4x decimation */
973
974 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
975
976 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
977
978 /* Finer search with 2x decimation */
979 for (int i = 0; i < max_pitch >> 1; i++) {
980 float sum;
981 xcorr[i] = 0;
982 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
983 continue;
984 sum = celt_inner_prod(x_lp, y+i, len>>1);
985 xcorr[i] = FFMAX(-1, sum);
986 }
987
988 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
989
990 /* Refine by pseudo-interpolation */
991 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
992 float a, b, c;
993
994 a = xcorr[best_pitch[0] - 1];
995 b = xcorr[best_pitch[0]];
996 c = xcorr[best_pitch[0] + 1];
997 if (c - a > .7f * (b - a))
998 offset = 1;
999 else if (a - c > .7f * (b-c))
1000 offset = -1;
1001 else
1002 offset = 0;
1003 } else {
1004 offset = 0;
1005 }
1006
1007 *pitch = 2 * best_pitch[0] - offset;
1008 }
1009
1010 static void dct(AudioRNNContext *s, float *out, const float *in)
1011 {
1012 for (int i = 0; i < NB_BANDS; i++) {
1013 float sum;
1014
1015 sum = s->fdsp->scalarproduct_float(in, s->dct_table[i], FFALIGN(NB_BANDS, 4));
1016 out[i] = sum * sqrtf(2.f / 22);
1017 }
1018 }
1019
1020 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1021 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1022 {
1023 float E = 0;
1024 float *ceps_0, *ceps_1, *ceps_2;
1025 float spec_variability = 0;
1026 LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
1027 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1028 float pitch_buf[PITCH_BUF_SIZE>>1];
1029 int pitch_index;
1030 float gain;
1031 float *(pre[1]);
1032 float tmp[NB_BANDS];
1033 float follow, logMax;
1034
1035 frame_analysis(s, st, X, Ex, in);
1036 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1037 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1038 pre[0] = &st->pitch_buf[0];
1039 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1040 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1041 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1042 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1043
1044 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1045 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1046 st->last_period = pitch_index;
1047 st->last_gain = gain;
1048
1049 for (int i = 0; i < WINDOW_SIZE; i++)
1050 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1051
1052 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1053 forward_transform(st, P, p);
1054 compute_band_energy(Ep, P);
1055 compute_band_corr(Exp, X, P);
1056
1057 for (int i = 0; i < NB_BANDS; i++)
1058 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1059
1060 dct(s, tmp, Exp);
1061
1062 for (int i = 0; i < NB_DELTA_CEPS; i++)
1063 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1064
1065 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1066 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1067 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1068 logMax = -2;
1069 follow = -2;
1070
1071 for (int i = 0; i < NB_BANDS; i++) {
1072 Ly[i] = log10f(1e-2f + Ex[i]);
1073 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1074 logMax = FFMAX(logMax, Ly[i]);
1075 follow = FFMAX(follow-1.5, Ly[i]);
1076 E += Ex[i];
1077 }
1078
1079 if (E < 0.04f) {
1080 /* If there's no audio, avoid messing up the state. */
1081 RNN_CLEAR(features, NB_FEATURES);
1082 return 1;
1083 }
1084
1085 dct(s, features, Ly);
1086 features[0] -= 12;
1087 features[1] -= 4;
1088 ceps_0 = st->cepstral_mem[st->memid];
1089 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1090 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1091
1092 for (int i = 0; i < NB_BANDS; i++)
1093 ceps_0[i] = features[i];
1094
1095 st->memid++;
1096 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1097 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1098 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1099 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1100 }
1101 /* Spectral variability features. */
1102 if (st->memid == CEPS_MEM)
1103 st->memid = 0;
1104
1105 for (int i = 0; i < CEPS_MEM; i++) {
1106 float mindist = 1e15f;
1107 for (int j = 0; j < CEPS_MEM; j++) {
1108 float dist = 0.f;
1109 for (int k = 0; k < NB_BANDS; k++) {
1110 float tmp;
1111
1112 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1113 dist += tmp*tmp;
1114 }
1115
1116 if (j != i)
1117 mindist = FFMIN(mindist, dist);
1118 }
1119
1120 spec_variability += mindist;
1121 }
1122
1123 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1124
1125 return 0;
1126 }
1127
1128 static void interp_band_gain(float *g, const float *bandE)
1129 {
1130 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1131
1132 for (int i = 0; i < NB_BANDS - 1; i++) {
1133 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1134
1135 for (int j = 0; j < band_size; j++) {
1136 float frac = (float)j / band_size;
1137
1138 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1139 }
1140 }
1141 }
1142
1143 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1144 const float *Exp, const float *g)
1145 {
1146 float newE[NB_BANDS];
1147 float r[NB_BANDS];
1148 float norm[NB_BANDS];
1149 float rf[FREQ_SIZE] = {0};
1150 float normf[FREQ_SIZE]={0};
1151
1152 for (int i = 0; i < NB_BANDS; i++) {
1153 if (Exp[i]>g[i]) r[i] = 1;
1154 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1155 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1156 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1157 }
1158 interp_band_gain(rf, r);
1159 for (int i = 0; i < FREQ_SIZE; i++) {
1160 X[i].re += rf[i]*P[i].re;
1161 X[i].im += rf[i]*P[i].im;
1162 }
1163 compute_band_energy(newE, X);
1164 for (int i = 0; i < NB_BANDS; i++) {
1165 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1166 }
1167 interp_band_gain(normf, norm);
1168 for (int i = 0; i < FREQ_SIZE; i++) {
1169 X[i].re *= normf[i];
1170 X[i].im *= normf[i];
1171 }
1172 }
1173
/*
 * Lookup table for tansig_approx().
 *
 * Entry i holds tanh(0.04 * i), rounded to six decimals. The 201 entries
 * cover x = 0 .. 8 in steps of 0.04; for example, entry 25 is tanh(1) = 0.761594.
 * tansig_approx() picks the nearest entry with i = floor(0.5 + 25 * |x|).
 */
static const float tansig_table[201] = {
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f,
};
1217
1218 static inline float tansig_approx(float x)
1219 {
1220 float y, dy;
1221 float sign=1;
1222 int i;
1223
1224 /* Tests are reversed to catch NaNs */
1225 if (!(x<8))
1226 return 1;
1227 if (!(x>-8))
1228 return -1;
1229 /* Another check in case of -ffast-math */
1230
1231 if (isnan(x))
1232 return 0;
1233
1234 if (x < 0) {
1235 x=-x;
1236 sign=-1;
1237 }
1238 i = (int)floor(.5f+25*x);
1239 x -= .04f*i;
1240 y = tansig_table[i];
1241 dy = 1-y*y;
1242 y = y + x*dy*(1 - y*x);
1243 return sign*y;
1244 }
1245
/**
 * Logistic sigmoid expressed through tanh: 1/(1+e^-x) = 0.5 + 0.5*tanh(x/2).
 * It inherits the saturation and NaN handling of tansig_approx().
 */
static inline float sigmoid_approx(float x)
{
    const float t = tansig_approx(.5f * x);

    return .5f * t + .5f;
}
1250
1251 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1252 {
1253 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1254
1255 for (int i = 0; i < N; i++) {
1256 /* Compute update gate. */
1257 float sum = layer->bias[i];
1258
1259 for (int j = 0; j < M; j++)
1260 sum += layer->input_weights[j * stride + i] * input[j];
1261
1262 output[i] = WEIGHTS_SCALE * sum;
1263 }
1264
1265 if (layer->activation == ACTIVATION_SIGMOID) {
1266 for (int i = 0; i < N; i++)
1267 output[i] = sigmoid_approx(output[i]);
1268 } else if (layer->activation == ACTIVATION_TANH) {
1269 for (int i = 0; i < N; i++)
1270 output[i] = tansig_approx(output[i]);
1271 } else if (layer->activation == ACTIVATION_RELU) {
1272 for (int i = 0; i < N; i++)
1273 output[i] = FFMAX(0, output[i]);
1274 } else {
1275 av_assert0(0);
1276 }
1277 }
1278
1279 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1280 {
1281 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1282 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1283 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1284 const int M = gru->nb_inputs;
1285 const int N = gru->nb_neurons;
1286 const int AN = FFALIGN(N, 4);
1287 const int AM = FFALIGN(M, 4);
1288 const int stride = 3 * AN, istride = 3 * AM;
1289
1290 for (int i = 0; i < N; i++) {
1291 /* Compute update gate. */
1292 float sum = gru->bias[i];
1293
1294 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1295 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1296 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1297 }
1298
1299 for (int i = 0; i < N; i++) {
1300 /* Compute reset gate. */
1301 float sum = gru->bias[N + i];
1302
1303 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1304 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1305 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1306 }
1307
1308 for (int i = 0; i < N; i++) {
1309 /* Compute output. */
1310 float sum = gru->bias[2 * N + i];
1311
1312 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1313 for (int j = 0; j < N; j++)
1314 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1315
1316 if (gru->activation == ACTIVATION_SIGMOID)
1317 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1318 else if (gru->activation == ACTIVATION_TANH)
1319 sum = tansig_approx(WEIGHTS_SCALE * sum);
1320 else if (gru->activation == ACTIVATION_RELU)
1321 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1322 else
1323 av_assert0(0);
1324 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1325 }
1326
1327 RNN_COPY(state, h, N);
1328 }
1329
1330 #define INPUT_SIZE 42
1331
1332 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1333 {
1334 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1335 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1336 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1337
1338 compute_dense(rnn->model->input_dense, dense_out, input);
1339 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1340 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1341
1342 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1343 memcpy(noise_input + rnn->model->input_dense_size,
1344 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1345 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1346 input, INPUT_SIZE * sizeof(float));
1347
1348 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1349
1350 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1351 memcpy(denoise_input + rnn->model->vad_gru_size,
1352 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1353 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1354 input, INPUT_SIZE * sizeof(float));
1355
1356 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1357 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1358 }
1359
1360 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1361 int disabled)
1362 {
1363 AVComplexFloat X[FREQ_SIZE];
1364 AVComplexFloat P[WINDOW_SIZE];
1365 float x[FRAME_SIZE];
1366 float Ex[NB_BANDS], Ep[NB_BANDS];
1367 LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1368 float features[NB_FEATURES];
1369 float g[NB_BANDS];
1370 float gf[FREQ_SIZE];
1371 float vad_prob = 0;
1372 float *history = st->history;
1373 static const float a_hp[2] = {-1.99599, 0.99600};
1374 static const float b_hp[2] = {-2, 1};
1375 int silence;
1376
1377 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1378 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1379
1380 if (!silence && !disabled) {
1381 compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
1382 pitch_filter(X, P, Ex, Ep, Exp, g);
1383 for (int i = 0; i < NB_BANDS; i++) {
1384 float alpha = .6f;
1385
1386 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1387 st->lastg[i] = g[i];
1388 }
1389
1390 interp_band_gain(gf, g);
1391
1392 for (int i = 0; i < FREQ_SIZE; i++) {
1393 X[i].re *= gf[i];
1394 X[i].im *= gf[i];
1395 }
1396 }
1397
1398 frame_synthesis(s, st, out, X);
1399 memcpy(history, in, FRAME_SIZE * sizeof(*history));
1400
1401 return vad_prob;
1402 }
1403
1404 typedef struct ThreadData {
1405 AVFrame *in, *out;
1406 } ThreadData;
1407
1408 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1409 {
1410 AudioRNNContext *s = ctx->priv;
1411 ThreadData *td = arg;
1412 AVFrame *in = td->in;
1413 AVFrame *out = td->out;
1414 const int start = (out->ch_layout.nb_channels * jobnr) / nb_jobs;
1415 const int end = (out->ch_layout.nb_channels * (jobnr+1)) / nb_jobs;
1416
1417 for (int ch = start; ch < end; ch++) {
1418 rnnoise_channel(s, &s->st[ch],
1419 (float *)out->extended_data[ch],
1420 (const float *)in->extended_data[ch],
1421 ctx->is_disabled);
1422 }
1423
1424 return 0;
1425 }
1426
1427 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1428 {
1429 AVFilterContext *ctx = inlink->dst;
1430 AVFilterLink *outlink = ctx->outputs[0];
1431 AVFrame *out = NULL;
1432 ThreadData td;
1433
1434 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1435 if (!out) {
1436 av_frame_free(&in);
1437 return AVERROR(ENOMEM);
1438 }
1439 out->pts = in->pts;
1440
1441 td.in = in; td.out = out;
1442 ff_filter_execute(ctx, rnnoise_channels, &td, NULL,
1443 FFMIN(outlink->ch_layout.nb_channels, ff_filter_get_nb_threads(ctx)));
1444
1445 av_frame_free(&in);
1446 return ff_filter_frame(outlink, out);
1447 }
1448
1449 static int activate(AVFilterContext *ctx)
1450 {
1451 AVFilterLink *inlink = ctx->inputs[0];
1452 AVFilterLink *outlink = ctx->outputs[0];
1453 AVFrame *in = NULL;
1454 int ret;
1455
1456 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1457
1458 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1459 if (ret < 0)
1460 return ret;
1461
1462 if (ret > 0)
1463 return filter_frame(inlink, in);
1464
1465 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1466 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1467
1468 return FFERROR_NOT_READY;
1469 }
1470
1471 static int open_model(AVFilterContext *ctx, RNNModel **model)
1472 {
1473 AudioRNNContext *s = ctx->priv;
1474 int ret;
1475 FILE *f;
1476
1477 if (!s->model_name)
1478 return AVERROR(EINVAL);
1479 f = avpriv_fopen_utf8(s->model_name, "r");
1480 if (!f) {
1481 av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
1482 return AVERROR(EINVAL);
1483 }
1484
1485 ret = rnnoise_model_from_file(f, model);
1486 fclose(f);
1487 if (!*model || ret < 0)
1488 return ret;
1489
1490 return 0;
1491 }
1492
1493 static av_cold int init(AVFilterContext *ctx)
1494 {
1495 AudioRNNContext *s = ctx->priv;
1496 int ret;
1497
1498 s->fdsp = avpriv_float_dsp_alloc(0);
1499 if (!s->fdsp)
1500 return AVERROR(ENOMEM);
1501
1502 ret = open_model(ctx, &s->model[0]);
1503 if (ret < 0)
1504 return ret;
1505
1506 for (int i = 0; i < FRAME_SIZE; i++) {
1507 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1508 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1509 }
1510
1511 for (int i = 0; i < NB_BANDS; i++) {
1512 for (int j = 0; j < NB_BANDS; j++) {
1513 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1514 if (j == 0)
1515 s->dct_table[j][i] *= sqrtf(.5);
1516 }
1517 }
1518
1519 return 0;
1520 }
1521
1522 static void free_model(AVFilterContext *ctx, int n)
1523 {
1524 AudioRNNContext *s = ctx->priv;
1525
1526 rnnoise_model_free(s->model[n]);
1527 s->model[n] = NULL;
1528
1529 for (int ch = 0; ch < s->channels && s->st; ch++) {
1530 av_freep(&s->st[ch].rnn[n].vad_gru_state);
1531 av_freep(&s->st[ch].rnn[n].noise_gru_state);
1532 av_freep(&s->st[ch].rnn[n].denoise_gru_state);
1533 }
1534 }
1535
1536 static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
1537 char *res, int res_len, int flags)
1538 {
1539 AudioRNNContext *s = ctx->priv;
1540 int ret;
1541
1542 ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
1543 if (ret < 0)
1544 return ret;
1545
1546 ret = open_model(ctx, &s->model[1]);
1547 if (ret < 0)
1548 return ret;
1549
1550 FFSWAP(RNNModel *, s->model[0], s->model[1]);
1551 for (int ch = 0; ch < s->channels; ch++)
1552 FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1553
1554 ret = config_input(ctx->inputs[0]);
1555 if (ret < 0) {
1556 for (int ch = 0; ch < s->channels; ch++)
1557 FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1558 FFSWAP(RNNModel *, s->model[0], s->model[1]);
1559 return ret;
1560 }
1561
1562 free_model(ctx, 1);
1563 return 0;
1564 }
1565
1566 static av_cold void uninit(AVFilterContext *ctx)
1567 {
1568 AudioRNNContext *s = ctx->priv;
1569
1570 av_freep(&s->fdsp);
1571 free_model(ctx, 0);
1572 for (int ch = 0; ch < s->channels && s->st; ch++) {
1573 av_tx_uninit(&s->st[ch].tx);
1574 av_tx_uninit(&s->st[ch].txi);
1575 }
1576 av_freep(&s->st);
1577 }
1578
1579 static const AVFilterPad inputs[] = {
1580 {
1581 .name = "default",
1582 .type = AVMEDIA_TYPE_AUDIO,
1583 .config_props = config_input,
1584 },
1585 };
1586
1587 static const AVFilterPad outputs[] = {
1588 {
1589 .name = "default",
1590 .type = AVMEDIA_TYPE_AUDIO,
1591 },
1592 };
1593
/* Offset of a field of the private context, for the option table below. */
#define OFFSET(x) offsetof(AudioRNNContext, x)
/* All options are audio filtering parameters that may also be changed at
 * runtime (see process_command()). */
#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM

static const AVOption arnndn_options[] = {
    /* Path of the model file opened by open_model(); "m" is a short alias. */
    { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    /* Blend between denoised output and input, range [-1, 1], default 1.
     * NOTE(review): it is applied outside this section; confirm the exact
     * formula there. */
    { "mix", "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
    { NULL }
};

AVFILTER_DEFINE_CLASS(arnndn);
1605
1606 const AVFilter ff_af_arnndn = {
1607 .name = "arnndn",
1608 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1609 .priv_size = sizeof(AudioRNNContext),
1610 .priv_class = &arnndn_class,
1611 .activate = activate,
1612 .init = init,
1613 .uninit = uninit,
1614 FILTER_INPUTS(inputs),
1615 FILTER_OUTPUTS(outputs),
1616 FILTER_QUERY_FUNC(query_formats),
1617 .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1618 AVFILTER_FLAG_SLICE_THREADS,
1619 .process_command = process_command,
1620 };
1621