LCOV - code coverage report
Current view: top level - libavfilter - vf_srcnn.c (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 191 0.0 %
Date: 2018-05-20 11:54:08 Functions: 0 10 0.0 %

          Line data    Source code
       1             : /*
       2             :  * Copyright (c) 2018 Sergey Lavrushkin
       3             :  *
       4             :  * This file is part of FFmpeg.
       5             :  *
       6             :  * FFmpeg is free software; you can redistribute it and/or
       7             :  * modify it under the terms of the GNU Lesser General Public
       8             :  * License as published by the Free Software Foundation; either
       9             :  * version 2.1 of the License, or (at your option) any later version.
      10             :  *
      11             :  * FFmpeg is distributed in the hope that it will be useful,
      12             :  * but WITHOUT ANY WARRANTY; without even the implied warranty of
      13             :  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      14             :  * Lesser General Public License for more details.
      15             :  *
      16             :  * You should have received a copy of the GNU Lesser General Public
      17             :  * License along with FFmpeg; if not, write to the Free Software
      18             :  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
      19             :  */
      20             : 
      21             : /**
      22             :  * @file
      23             :  * Filter implementing image super-resolution using deep convolutional networks.
      24             :  * https://arxiv.org/abs/1501.00092
      25             :  */
      26             : 
      27             : #include "avfilter.h"
      28             : #include "formats.h"
      29             : #include "internal.h"
      30             : #include "libavutil/opt.h"
      31             : #include "vf_srcnn.h"
      32             : #include "libavformat/avio.h"
      33             : 
/// One convolutional layer of the SRCNN network.
typedef struct Convolution
{
    double* kernel;   ///< weights: output_channels * size * size * input_channels doubles
    double* biases;   ///< one bias per output channel (output_channels doubles)
    int32_t size, input_channels, output_channels;   ///< size is the square kernel's side length
} Convolution;
      40             : 
/// Filter private context: the three-layer network and its working buffers.
typedef struct SRCNNContext {
    const AVClass *class;

    /// SRCNN convolutions (layer 1 -> 2 -> 3, applied in order)
    struct Convolution conv1, conv2, conv3;
    /// Path to binary file with kernels specifications (NULL selects built-in x2 weights)
    char* config_file_path;
    /// Buffers for network input/output and feature maps;
    /// allocated in config_props(), freed in uninit()
    double* input_output_buf;   ///< holds network input, reused for final output
    double* conv1_buf;          ///< feature maps produced by conv1
    double* conv2_buf;          ///< feature maps produced by conv2
} SRCNNContext;
      53             : 
      54             : 
#define OFFSET(x) offsetof(SRCNNContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
/// Filter options: only the optional path to a weights file.
static const AVOption srcnn_options[] = {
    { "config_file", "path to configuration file with network parameters", OFFSET(config_file_path), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
    { NULL }
};

AVFILTER_DEFINE_CLASS(srcnn);
      63             : 
/// Bail out of init() with EIO when the expected payload size exceeds the
/// actual file size, closing the AVIO handle first.
/// NOTE(review): deliberately a bare if-block rather than do { } while (0) —
/// the call sites in init() invoke these macros WITHOUT a trailing semicolon,
/// so the usual multi-statement wrapping would not compile there.
/// Both macros expect an AVFilterContext* named `context` in the calling scope.
#define CHECK_FILE_SIZE(file_size, srcnn_size, avio_context)    if (srcnn_size > file_size){ \
                                                                    av_log(context, AV_LOG_ERROR, "error reading configuration file\n");\
                                                                    avio_closep(&avio_context); \
                                                                    return AVERROR(EIO); \
                                                                }

/// Bail out of init() with ENOMEM when `call` returns non-zero,
/// running `end_call` (cleanup, possibly empty) first.
#define CHECK_ALLOCATION(call, end_call)    if (call){ \
                                                av_log(context, AV_LOG_ERROR, "could not allocate memory for convolutions\n"); \
                                                end_call; \
                                                return AVERROR(ENOMEM); \
                                            }
      75             : 
      76           0 : static int allocate_read_conv_data(Convolution* conv, AVIOContext* config_file_context)
      77             : {
      78           0 :     int32_t kernel_size = conv->output_channels * conv->size * conv->size * conv->input_channels;
      79             :     int32_t i;
      80             : 
      81           0 :     conv->kernel = av_malloc(kernel_size * sizeof(double));
      82           0 :     if (!conv->kernel){
      83           0 :         return AVERROR(ENOMEM);
      84             :     }
      85           0 :     for (i = 0; i < kernel_size; ++i){
      86           0 :         conv->kernel[i] = av_int2double(avio_rl64(config_file_context));
      87             :     }
      88             : 
      89           0 :     conv->biases = av_malloc(conv->output_channels * sizeof(double));
      90           0 :     if (!conv->biases){
      91           0 :         return AVERROR(ENOMEM);
      92             :     }
      93           0 :     for (i = 0; i < conv->output_channels; ++i){
      94           0 :         conv->biases[i] = av_int2double(avio_rl64(config_file_context));
      95             :     }
      96             : 
      97           0 :     return 0;
      98             : }
      99             : 
     100           0 : static int allocate_copy_conv_data(Convolution* conv, const double* kernel, const double* biases)
     101             : {
     102           0 :     int32_t kernel_size = conv->output_channels * conv->size * conv->size * conv->input_channels;
     103             : 
     104           0 :     conv->kernel = av_malloc(kernel_size * sizeof(double));
     105           0 :     if (!conv->kernel){
     106           0 :         return AVERROR(ENOMEM);
     107             :     }
     108           0 :     memcpy(conv->kernel, kernel, kernel_size * sizeof(double));
     109             : 
     110           0 :     conv->biases = av_malloc(conv->output_channels * sizeof(double));
     111           0 :     if (!conv->kernel){
     112           0 :         return AVERROR(ENOMEM);
     113             :     }
     114           0 :     memcpy(conv->biases, biases, conv->output_channels * sizeof(double));
     115             : 
     116           0 :     return 0;
     117             : }
     118             : 
     119           0 : static av_cold int init(AVFilterContext* context)
     120             : {
     121           0 :     SRCNNContext *srcnn_context = context->priv;
     122             :     AVIOContext* config_file_context;
     123             :     int64_t file_size, srcnn_size;
     124             : 
     125             :     /// Check specified confguration file name and read network weights from it
     126           0 :     if (!srcnn_context->config_file_path){
     127           0 :         av_log(context, AV_LOG_INFO, "configuration file for network was not specified, using default weights for x2 upsampling\n");
     128             : 
     129             :         /// Create convolution kernels and copy default weights
     130           0 :         srcnn_context->conv1.input_channels = 1;
     131           0 :         srcnn_context->conv1.output_channels = 64;
     132           0 :         srcnn_context->conv1.size = 9;
     133           0 :         CHECK_ALLOCATION(allocate_copy_conv_data(&srcnn_context->conv1, conv1_kernel, conv1_biases), )
     134             : 
     135           0 :         srcnn_context->conv2.input_channels = 64;
     136           0 :         srcnn_context->conv2.output_channels = 32;
     137           0 :         srcnn_context->conv2.size = 1;
     138           0 :         CHECK_ALLOCATION(allocate_copy_conv_data(&srcnn_context->conv2, conv2_kernel, conv2_biases), )
     139             : 
     140           0 :         srcnn_context->conv3.input_channels = 32;
     141           0 :         srcnn_context->conv3.output_channels = 1;
     142           0 :         srcnn_context->conv3.size = 5;
     143           0 :         CHECK_ALLOCATION(allocate_copy_conv_data(&srcnn_context->conv3, conv3_kernel, conv3_biases), )
     144             :     }
     145           0 :     else if (avio_check(srcnn_context->config_file_path, AVIO_FLAG_READ) > 0){
     146           0 :         if (avio_open(&config_file_context, srcnn_context->config_file_path, AVIO_FLAG_READ) < 0){
     147           0 :             av_log(context, AV_LOG_ERROR, "failed to open configuration file\n");
     148           0 :             return AVERROR(EIO);
     149             :         }
     150             : 
     151           0 :         file_size = avio_size(config_file_context);
     152             : 
     153             :         /// Create convolution kernels and read weights from file
     154           0 :         srcnn_context->conv1.input_channels = 1;
     155           0 :         srcnn_context->conv1.size = (int32_t)avio_rl32(config_file_context);
     156           0 :         srcnn_context->conv1.output_channels = (int32_t)avio_rl32(config_file_context);
     157           0 :         srcnn_size = 8 + (srcnn_context->conv1.output_channels * srcnn_context->conv1.size *
     158           0 :                           srcnn_context->conv1.size * srcnn_context->conv1.input_channels +
     159           0 :                           srcnn_context->conv1.output_channels << 3);
     160           0 :         CHECK_FILE_SIZE(file_size, srcnn_size, config_file_context)
     161           0 :         CHECK_ALLOCATION(allocate_read_conv_data(&srcnn_context->conv1, config_file_context), avio_closep(&config_file_context))
     162             : 
     163           0 :         srcnn_context->conv2.input_channels = (int32_t)avio_rl32(config_file_context);
     164           0 :         srcnn_context->conv2.size = (int32_t)avio_rl32(config_file_context);
     165           0 :         srcnn_context->conv2.output_channels = (int32_t)avio_rl32(config_file_context);
     166           0 :         srcnn_size += 12 + (srcnn_context->conv2.output_channels * srcnn_context->conv2.size *
     167           0 :                             srcnn_context->conv2.size * srcnn_context->conv2.input_channels +
     168           0 :                             srcnn_context->conv2.output_channels << 3);
     169           0 :         CHECK_FILE_SIZE(file_size, srcnn_size, config_file_context)
     170           0 :         CHECK_ALLOCATION(allocate_read_conv_data(&srcnn_context->conv2, config_file_context), avio_closep(&config_file_context))
     171             : 
     172           0 :         srcnn_context->conv3.input_channels = (int32_t)avio_rl32(config_file_context);
     173           0 :         srcnn_context->conv3.size = (int32_t)avio_rl32(config_file_context);
     174           0 :         srcnn_context->conv3.output_channels = 1;
     175           0 :         srcnn_size += 8 + (srcnn_context->conv3.output_channels * srcnn_context->conv3.size *
     176           0 :                            srcnn_context->conv3.size * srcnn_context->conv3.input_channels
     177           0 :                            + srcnn_context->conv3.output_channels << 3);
     178           0 :         if (file_size != srcnn_size){
     179           0 :             av_log(context, AV_LOG_ERROR, "error reading configuration file\n");
     180           0 :             avio_closep(&config_file_context);
     181           0 :             return AVERROR(EIO);
     182             :         }
     183           0 :         CHECK_ALLOCATION(allocate_read_conv_data(&srcnn_context->conv3, config_file_context), avio_closep(&config_file_context))
     184             : 
     185           0 :         avio_closep(&config_file_context);
     186             :     }
     187             :     else{
     188           0 :         av_log(context, AV_LOG_ERROR, "specified configuration file does not exist or not readable\n");
     189           0 :         return AVERROR(EIO);
     190             :     }
     191             : 
     192           0 :     return 0;
     193             : }
     194             : 
     195           0 : static int query_formats(AVFilterContext* context)
     196             : {
     197           0 :     const enum AVPixelFormat pixel_formats[] = {AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV444P,
     198             :                                                 AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P, AV_PIX_FMT_GRAY8,
     199             :                                                 AV_PIX_FMT_NONE};
     200             :     AVFilterFormats *formats_list;
     201             : 
     202           0 :     formats_list = ff_make_format_list(pixel_formats);
     203           0 :     if (!formats_list){
     204           0 :         av_log(context, AV_LOG_ERROR, "could not create formats list\n");
     205           0 :         return AVERROR(ENOMEM);
     206             :     }
     207           0 :     return ff_set_common_formats(context, formats_list);
     208             : }
     209             : 
     210           0 : static int config_props(AVFilterLink* inlink)
     211             : {
     212           0 :     AVFilterContext *context = inlink->dst;
     213           0 :     SRCNNContext *srcnn_context = context->priv;
     214             :     int min_dim;
     215             : 
     216             :     /// Check if input data width or height is too low
     217           0 :     min_dim = FFMIN(inlink->w, inlink->h);
     218           0 :     if (min_dim <= srcnn_context->conv1.size >> 1 || min_dim <= srcnn_context->conv2.size >> 1 || min_dim <= srcnn_context->conv3.size >> 1){
     219           0 :         av_log(context, AV_LOG_ERROR, "input width or height is too low\n");
     220           0 :         return AVERROR(EIO);
     221             :     }
     222             : 
     223             :     /// Allocate network buffers
     224           0 :     srcnn_context->input_output_buf = av_malloc(inlink->h * inlink->w * sizeof(double));
     225           0 :     srcnn_context->conv1_buf = av_malloc(inlink->h * inlink->w * srcnn_context->conv1.output_channels * sizeof(double));
     226           0 :     srcnn_context->conv2_buf = av_malloc(inlink->h * inlink->w * srcnn_context->conv2.output_channels * sizeof(double));
     227             : 
     228           0 :     if (!srcnn_context->input_output_buf || !srcnn_context->conv1_buf || !srcnn_context->conv2_buf){
     229           0 :         av_log(context, AV_LOG_ERROR, "could not allocate memory for srcnn buffers\n");
     230           0 :         return AVERROR(ENOMEM);
     231             :     }
     232             : 
     233           0 :     return 0;
     234             : }
     235             : 
/// Per-slice job parameters for the uint8 <-> double conversion passes.
typedef struct ThreadData{
    uint8_t* out;                      ///< base pointer of the frame plane being converted
    int out_linesize, height, width;   ///< plane stride in bytes, and plane dimensions in pixels
} ThreadData;

/// Per-slice job parameters for one convolution pass.
typedef struct ConvThreadData
{
    const Convolution* conv;   ///< layer weights and dimensions
    const double* input;       ///< input maps, width * input_channels doubles per row
    double* output;            ///< output maps, width * output_channels doubles per row
    int height, width;         ///< spatial dimensions of the feature maps
} ConvThreadData;
     248             : 
     249             : /// Convert uint8 data to double and scale it to use in network
     250           0 : static int uint8_to_double(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
     251             : {
     252           0 :     SRCNNContext* srcnn_context = context->priv;
     253           0 :     const ThreadData* td = arg;
     254           0 :     const int slice_start = (td->height *  jobnr     ) / nb_jobs;
     255           0 :     const int slice_end   = (td->height * (jobnr + 1)) / nb_jobs;
     256           0 :     const uint8_t* src = td->out + slice_start * td->out_linesize;
     257           0 :     double* dst = srcnn_context->input_output_buf + slice_start * td->width;
     258             :     int y, x;
     259             : 
     260           0 :     for (y = slice_start; y < slice_end; ++y){
     261           0 :         for (x = 0; x < td->width; ++x){
     262           0 :             dst[x] = (double)src[x] / 255.0;
     263             :         }
     264           0 :         src += td->out_linesize;
     265           0 :         dst += td->width;
     266             :     }
     267             : 
     268           0 :     return 0;
     269             : }
     270             : 
     271             : /// Convert double data from network to uint8 and scale it to output as filter result
     272           0 : static int double_to_uint8(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
     273             : {
     274           0 :     SRCNNContext* srcnn_context = context->priv;
     275           0 :     const ThreadData* td = arg;
     276           0 :     const int slice_start = (td->height *  jobnr     ) / nb_jobs;
     277           0 :     const int slice_end   = (td->height * (jobnr + 1)) / nb_jobs;
     278           0 :     const double* src = srcnn_context->input_output_buf + slice_start * td->width;
     279           0 :     uint8_t* dst = td->out + slice_start * td->out_linesize;
     280             :     int y, x;
     281             : 
     282           0 :     for (y = slice_start; y < slice_end; ++y){
     283           0 :         for (x = 0; x < td->width; ++x){
     284           0 :             dst[x] = (uint8_t)(255.0 * FFMIN(src[x], 1.0));
     285             :         }
     286           0 :         src += td->width;
     287           0 :         dst += td->out_linesize;
     288             :     }
     289             : 
     290           0 :     return 0;
     291             : }
     292             : 
     293             : #define CLAMP_TO_EDGE(x, w) ((x) < 0 ? 0 : ((x) >= (w) ? (w - 1) : (x)))
     294             : 
     295           0 : static int convolve(AVFilterContext* context, void* arg, int jobnr, int nb_jobs)
     296             : {
     297           0 :     const ConvThreadData* td = arg;
     298           0 :     const int slice_start = (td->height *  jobnr     ) / nb_jobs;
     299           0 :     const int slice_end   = (td->height * (jobnr + 1)) / nb_jobs;
     300           0 :     const double* src = td->input;
     301           0 :     double* dst = td->output + slice_start * td->width * td->conv->output_channels;
     302             :     int y, x;
     303             :     int32_t n_filter, ch, kernel_y, kernel_x;
     304           0 :     int32_t radius = td->conv->size >> 1;
     305           0 :     int src_linesize = td->width * td->conv->input_channels;
     306           0 :     int filter_linesize = td->conv->size * td->conv->input_channels;
     307           0 :     int filter_size = td->conv->size * filter_linesize;
     308             : 
     309           0 :     for (y = slice_start; y < slice_end; ++y){
     310           0 :         for (x = 0; x < td->width; ++x){
     311           0 :             for (n_filter = 0; n_filter < td->conv->output_channels; ++n_filter){
     312           0 :                 dst[n_filter] = td->conv->biases[n_filter];
     313           0 :                 for (ch = 0; ch < td->conv->input_channels; ++ch){
     314           0 :                     for (kernel_y = 0; kernel_y < td->conv->size; ++kernel_y){
     315           0 :                         for (kernel_x = 0; kernel_x < td->conv->size; ++kernel_x){
     316           0 :                             dst[n_filter] += src[CLAMP_TO_EDGE(y + kernel_y - radius, td->height) * src_linesize +
     317           0 :                                                  CLAMP_TO_EDGE(x + kernel_x - radius, td->width) * td->conv->input_channels + ch] *
     318           0 :                                              td->conv->kernel[n_filter * filter_size + kernel_y * filter_linesize +
     319           0 :                                                               kernel_x * td->conv->input_channels + ch];
     320             :                         }
     321             :                     }
     322             :                 }
     323           0 :                 dst[n_filter] = FFMAX(dst[n_filter], 0.0);
     324             :             }
     325           0 :             dst += td->conv->output_channels;
     326             :         }
     327             :     }
     328             : 
     329           0 :     return 0;
     330             : }
     331             : 
/**
 * Per-frame entry point: run the three-layer network over the first plane
 * (data[0], the luma plane for YUV input) of the frame.
 *
 * The input frame is copied to a fresh output buffer and freed; only plane 0
 * is then overwritten with the network result, so the remaining planes pass
 * through unchanged from the copy. Takes ownership of @p in.
 */
static int filter_frame(AVFilterLink* inlink, AVFrame* in)
{
    AVFilterContext* context = inlink->dst;
    SRCNNContext* srcnn_context = context->priv;
    AVFilterLink* outlink = context->outputs[0];
    AVFrame* out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
    ThreadData td;
    ConvThreadData ctd;
    int nb_threads;

    if (!out){
        av_log(context, AV_LOG_ERROR, "could not allocate memory for output frame\n");
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }
    av_frame_copy_props(out, in);
    av_frame_copy(out, in);
    av_frame_free(&in);
    td.out = out->data[0];
    td.out_linesize = out->linesize[0];
    td.height = ctd.height = out->height;
    td.width = ctd.width = out->width;

    /// Pipeline over shared buffers, each stage sliced across threads:
    /// plane 0 -> input_output_buf -> conv1_buf -> conv2_buf ->
    /// input_output_buf (reused for the final result) -> plane 0.
    /// The stage order is load-bearing: each execute() consumes the buffer
    /// the previous one produced.
    nb_threads = ff_filter_get_nb_threads(context);
    context->internal->execute(context, uint8_to_double, &td, NULL, FFMIN(td.height, nb_threads));
    ctd.conv = &srcnn_context->conv1;
    ctd.input = srcnn_context->input_output_buf;
    ctd.output = srcnn_context->conv1_buf;
    context->internal->execute(context, convolve, &ctd, NULL, FFMIN(ctd.height, nb_threads));
    ctd.conv = &srcnn_context->conv2;
    ctd.input = srcnn_context->conv1_buf;
    ctd.output = srcnn_context->conv2_buf;
    context->internal->execute(context, convolve, &ctd, NULL, FFMIN(ctd.height, nb_threads));
    ctd.conv = &srcnn_context->conv3;
    ctd.input = srcnn_context->conv2_buf;
    ctd.output = srcnn_context->input_output_buf;
    context->internal->execute(context, convolve, &ctd, NULL, FFMIN(ctd.height, nb_threads));
    context->internal->execute(context, double_to_uint8, &td, NULL, FFMIN(td.height, nb_threads));

    return ff_filter_frame(outlink, out);
}
     373             : 
     374           0 : static av_cold void uninit(AVFilterContext* context)
     375             : {
     376           0 :     SRCNNContext* srcnn_context = context->priv;
     377             : 
     378             :     /// Free convolution data
     379           0 :     av_freep(&srcnn_context->conv1.kernel);
     380           0 :     av_freep(&srcnn_context->conv1.biases);
     381           0 :     av_freep(&srcnn_context->conv2.kernel);
     382           0 :     av_freep(&srcnn_context->conv2.biases);
     383           0 :     av_freep(&srcnn_context->conv3.kernel);
     384           0 :     av_freep(&srcnn_context->conv3.kernel);
     385             : 
     386             :     /// Free network buffers
     387           0 :     av_freep(&srcnn_context->input_output_buf);
     388           0 :     av_freep(&srcnn_context->conv1_buf);
     389           0 :     av_freep(&srcnn_context->conv2_buf);
     390           0 : }
     391             : 
/// Single video input: validates dimensions and allocates buffers on link
/// configuration, then processes each incoming frame.
static const AVFilterPad srcnn_inputs[] = {
    {
        .name         = "default",
        .type         = AVMEDIA_TYPE_VIDEO,
        .config_props = config_props,
        .filter_frame = filter_frame,
    },
    { NULL }
};

/// Single video output with default pad behavior.
static const AVFilterPad srcnn_outputs[] = {
    {
        .name = "default",
        .type = AVMEDIA_TYPE_VIDEO,
    },
    { NULL }
};
     409             : 
     410             : AVFilter ff_vf_srcnn = {
     411             :     .name          = "srcnn",
     412             :     .description   = NULL_IF_CONFIG_SMALL("Apply super resolution convolutional neural network to the input. Use bicubic upsamping with corresponding scaling factor before."),
     413             :     .priv_size     = sizeof(SRCNNContext),
     414             :     .init          = init,
     415             :     .uninit        = uninit,
     416             :     .query_formats = query_formats,
     417             :     .inputs        = srcnn_inputs,
     418             :     .outputs       = srcnn_outputs,
     419             :     .priv_class    = &srcnn_class,
     420             :     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_GENERIC | AVFILTER_FLAG_SLICE_THREADS,
     421             : };

Generated by: LCOV version 1.13