Thursday, December 20, 2018

LCECBF: Linear Cost Exact Circular Bokeh Filter

Depth of field has always been a costly post-process for video games. In particular, the circle-of-confusion filter has proven to be a bottleneck. While deceptively simple (all weights are either zero or the same constant), it is non-separable: unlike a box or Gaussian filter, it cannot be computed with two O(r) passes (where r is the filter radius). An accurate result instead requires O(r * r) samples per pixel.
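To make the cost gap concrete, here is a minimal C++ sketch of the brute-force disc filter on a single-channel CPU image. The helper names `disc_offsets` and `disc_filter_at` are hypothetical. The sketch assumes the kernel is the digital disc of all integer offsets with dx² + dy² ≤ r², and that out-of-range pixels read as 0. For r = 10 this gives 317 taps, the same count (and order) as the shader's `vCircleSamples` table.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Offset = std::pair<int, int>;  // (dx, dy)

// Hypothetical helper: every integer offset inside a digital disc of radius r,
// i.e. dx*dx + dy*dy <= r*r, enumerated column by column (dx outer, dy inner).
std::vector<Offset> disc_offsets(int r) {
    assert(r >= 0);
    std::vector<Offset> offs;
    for (int dx = -r; dx <= r; ++dx)
        for (int dy = -r; dy <= r; ++dy)
            if (dx * dx + dy * dy <= r * r) offs.emplace_back(dx, dy);
    return offs;
}

// Brute-force disc filter at pixel (x, y) of a w x h single-channel image
// stored row-major. Out-of-range taps read as 0 (an assumption of this sketch).
// Cost: one fetch per tap, i.e. O(r*r) fetches per output pixel.
double disc_filter_at(const std::vector<double>& img, int w, int h,
                      int x, int y, const std::vector<Offset>& taps) {
    double sum = 0.0;
    for (const Offset& o : taps) {
        const int sx = x + o.first, sy = y + o.second;
        if (sx >= 0 && sx < w && sy >= 0 && sy < h) sum += img[sy * w + sx];
    }
    return sum / static_cast<double>(taps.size());
}
```

For r = 10, the brute force costs 317 fetches per pixel. A two-pass separable box of the same width needs only 2 × 21 = 42.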

There have been many attempts to solve this issue. Apologies for not giving proper bibliographic references; I will simply list the approaches:
  • NFS approach: works only for polygonal (e.g., hexagonal) bokeh; does a horizontal pass and then 2 diagonal passes, combined with a max filter (has some artifacts)
  • Crytek 2-pass approach with rotating kernel and flood fill
  • recent Fourier-series approach from EA (published at GDC18)
While practical, these approaches still lack the following qualities:
  • being accurate
  • requiring only one pass/no additional render targets
The approach I want to present here can apply a circular (or, in fact, any convex) bokeh filter to an image:
  • with an amortized O(r) number of samples per pixel, where r is the filter radius
  • matching the ground truth up to floating-point error
  • in a single pass
  • without any additional memory allocation
The proposed method uses an idea that my supervisor, Alexey Ignatenko, hinted to me about 10 years ago, so here is a shout-out to him. The key idea is that once you have computed a convolution with a constant-weight kernel at one point, you can reuse most of it for the neighboring point:
If we compute the convolution at point A, we can reuse most of it for point B (the blue part, where the kernels centered at A and B overlap). This works because all the weights are equal, an intrinsic property of the bokeh filter kernel. To compute the convolution at B, all we have to do is take A's convolution, add the pixels in green (covered only by B's kernel), and subtract the pixels in pink (covered only by A's kernel).
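The green and pink sets can be derived mechanically from the kernel's tap list. Below is a hedged C++ sketch; the names `SlideDelta` and `slide_right_delta` are hypothetical. It assumes the kernel slides one pixel to the right and expresses both sets as offsets relative to the new pixel B, which is how the shader's `vCircleSamplesNeg`/`vCircleSamplesPos` tables are laid out.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using Offset = std::pair<int, int>;  // (dx, dy)

// Offsets relative to the NEW pixel B = A + (1, 0).
struct SlideDelta {
    std::vector<Offset> leaving;   // pink: covered at A, not at B (subtract)
    std::vector<Offset> entering;  // green: covered at B, not at A (add)
};

// Hypothetical helper: derive the slide-right deltas of a constant-weight
// kernel from its tap list. Linear search keeps it short; fine for a
// one-off table generator, not for per-pixel use.
SlideDelta slide_right_delta(const std::vector<Offset>& taps) {
    assert(!taps.empty());
    auto has = [&taps](int dx, int dy) {
        return std::find(taps.begin(), taps.end(), Offset(dx, dy)) != taps.end();
    };
    SlideDelta d;
    for (const Offset& o : taps) {
        // Pixel A + o sits at offset o - (1, 0) from B; it leaves if that is not a tap.
        if (!has(o.first - 1, o.second)) d.leaving.emplace_back(o.first - 1, o.second);
        // Pixel B + o is new unless it was in A's kernel, i.e. unless o + (1, 0) is a tap.
        if (!has(o.first + 1, o.second)) d.entering.push_back(o);
    }
    return d;
}
```

Fed the radius-10 disc (all offsets with dx² + dy² ≤ 100, enumerated dx-major like `vCircleSamples`), this reproduces both 21-entry tables of the shader, entry for entry.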

How many pixels would this be? Perimeter-many :) For a circle (and for regular polygons), this is a linear function of the radius.
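Stated a little more formally (the notation is mine, not from the original post): let S be the set of kernel offsets, I the input image, and B = A + (1, 0). Every row of a disc is a single contiguous run of taps. Sliding one pixel to the right therefore removes exactly one pixel just left of each row and adds one at each row's right end:

```latex
\begin{aligned}
C(p) &= \frac{1}{|S|} \sum_{d \in S} I(p + d), \\
S^{-} &= \{\, d \notin S : d + (1,0) \in S \,\}, \qquad
S^{+} = \{\, d \in S : d + (1,0) \notin S \,\}, \\
C(B) &= C(A) + \frac{1}{|S|} \Big( \sum_{d \in S^{+}} I(B + d) - \sum_{d \in S^{-}} I(B + d) \Big), \\
|S^{+}| &= |S^{-}| = 2r + 1 \quad \text{(one pixel per row, rows } -r, \dots, r \text{ of a disc of radius } r\text{)}.
\end{aligned}
```

For r = 10, that is 21 + 21 = 42 fetches per step instead of |S| = 317, matching the sizes of the shader's `vCircleSamplesNeg` and `vCircleSamplesPos` tables.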

How does this translate to shader code? Well, we can compute a full convolution for just one pixel and then propagate it to its neighbors at linear cost. Naturally, this requires each compute shader thread to output more than one pixel.
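Here is a CPU sketch of that propagation in C++. It is single-channel, and the name `filter_run` is hypothetical. It assumes out-of-range pixels read as 0, and it follows the same structure as the shader: one full convolution for the first pixel of a run, then only the entering/leaving fetches for each following pixel.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Offset = std::pair<int, int>;  // (dx, dy)

// Hypothetical sketch: filter a horizontal run of `run` pixels starting at
// (x0, y) of a w x h row-major image; out-of-range taps read as 0.
// `taps` is the full kernel; `leaving`/`entering` are its slide-right deltas,
// given as offsets relative to the new pixel (like the shader's Neg/Pos tables).
std::vector<double> filter_run(const std::vector<double>& img, int w, int h,
                               int x0, int y, int run,
                               const std::vector<Offset>& taps,
                               const std::vector<Offset>& leaving,
                               const std::vector<Offset>& entering) {
    assert(run >= 1 && !taps.empty());
    auto fetch = [&](int x, int yy) {
        return (x >= 0 && x < w && yy >= 0 && yy < h) ? img[yy * w + x] : 0.0;
    };
    const double n = static_cast<double>(taps.size());
    std::vector<double> out;
    out.reserve(static_cast<std::size_t>(run));

    // 1) Full O(|S|) convolution for the first pixel of the run.
    double sum = 0.0;
    for (const Offset& o : taps) sum += fetch(x0 + o.first, y + o.second);
    out.push_back(sum / n);

    // 2) Every further pixel reuses the running sum and fetches only the
    //    O(perimeter) pixels that enter or leave the kernel.
    for (int i = 1; i < run; ++i) {
        const int x = x0 + i;
        for (const Offset& o : entering) sum += fetch(x + o.first, y + o.second);
        for (const Offset& o : leaving) sum -= fetch(x + o.first, y + o.second);
        out.push_back(sum / n);
    }
    return out;
}
```

One design difference: the shader divides each delta by the tap count and adds it to the previous normalized result. This sketch instead keeps the unnormalized running sum and divides once per pixel. The two are equivalent up to rounding. With small integer-valued inputs, the running sum here is exact, so the output equals the brute-force filter bit for bit.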

Here is the shader code (I apologize for the hardcoded constants and other less-than-ideal things):

Texture2D Input : register( t0 );
RWTexture2D<float4> Result : register( u0 );


// One thread group is a column of 32 threads; each thread filters 64 pixels of its row.
[numthreads( 1, 32, 1 )]
void CSMain( uint3 Gid : SV_GroupID, uint GI : SV_GroupIndex, uint3 DTid : SV_DispatchThreadID )
{
    // All integer offsets (x, y) with x*x + y*y <= 10*10: a disc of radius 10 (21x21 bounding box).
    const int nTotalSamples = 317;
    const int2 vCircleSamples[317] =
    {
        int2(-10, 0),     int2(-9, -4),     int2(-9, -3),     int2(-9, -2),     int2(-9, -1),     int2(-9, 0),     int2(-9, 1),     int2(-9, 2),     int2(-9, 3),     int2(-9, 4),
        int2(-8, -6),     int2(-8, -5),     int2(-8, -4),     int2(-8, -3),     int2(-8, -2),     int2(-8, -1),     int2(-8, 0),     int2(-8, 1),     int2(-8, 2),     int2(-8, 3),
        int2(-8, 4),     int2(-8, 5),     int2(-8, 6),     int2(-7, -7),     int2(-7, -6),     int2(-7, -5),     int2(-7, -4),     int2(-7, -3),     int2(-7, -2),     int2(-7, -1),
        int2(-7, 0),     int2(-7, 1),     int2(-7, 2),     int2(-7, 3),     int2(-7, 4),     int2(-7, 5),     int2(-7, 6),     int2(-7, 7),     int2(-6, -8),     int2(-6, -7),
        int2(-6, -6),     int2(-6, -5),     int2(-6, -4),     int2(-6, -3),     int2(-6, -2),     int2(-6, -1),     int2(-6, 0),     int2(-6, 1),     int2(-6, 2),     int2(-6, 3),
        int2(-6, 4),     int2(-6, 5),     int2(-6, 6),     int2(-6, 7),     int2(-6, 8),     int2(-5, -8),     int2(-5, -7),     int2(-5, -6),     int2(-5, -5),     int2(-5, -4),
        int2(-5, -3),     int2(-5, -2),     int2(-5, -1),     int2(-5, 0),     int2(-5, 1),     int2(-5, 2),     int2(-5, 3),     int2(-5, 4),     int2(-5, 5),     int2(-5, 6),
        int2(-5, 7),     int2(-5, 8),     int2(-4, -9),     int2(-4, -8),     int2(-4, -7),     int2(-4, -6),     int2(-4, -5),     int2(-4, -4),     int2(-4, -3),     int2(-4, -2),
        int2(-4, -1),     int2(-4, 0),     int2(-4, 1),     int2(-4, 2),     int2(-4, 3),     int2(-4, 4),     int2(-4, 5),     int2(-4, 6),     int2(-4, 7),     int2(-4, 8),
        int2(-4, 9),     int2(-3, -9),     int2(-3, -8),     int2(-3, -7),     int2(-3, -6),     int2(-3, -5),     int2(-3, -4),     int2(-3, -3),     int2(-3, -2),     int2(-3, -1),
        int2(-3, 0),     int2(-3, 1),     int2(-3, 2),     int2(-3, 3),     int2(-3, 4),     int2(-3, 5),     int2(-3, 6),     int2(-3, 7),     int2(-3, 8),     int2(-3, 9),
        int2(-2, -9),     int2(-2, -8),     int2(-2, -7),     int2(-2, -6),     int2(-2, -5),     int2(-2, -4),     int2(-2, -3),     int2(-2, -2),     int2(-2, -1),     int2(-2, 0),
        int2(-2, 1),     int2(-2, 2),     int2(-2, 3),     int2(-2, 4),     int2(-2, 5),     int2(-2, 6),     int2(-2, 7),     int2(-2, 8),     int2(-2, 9),     int2(-1, -9),
        int2(-1, -8),     int2(-1, -7),     int2(-1, -6),     int2(-1, -5),     int2(-1, -4),     int2(-1, -3),     int2(-1, -2),     int2(-1, -1),     int2(-1, 0),     int2(-1, 1),
        int2(-1, 2),     int2(-1, 3),     int2(-1, 4),     int2(-1, 5),     int2(-1, 6),     int2(-1, 7),     int2(-1, 8),     int2(-1, 9),     int2(0, -10),     int2(0, -9),
        int2(0, -8),     int2(0, -7),     int2(0, -6),     int2(0, -5),     int2(0, -4),     int2(0, -3),     int2(0, -2),     int2(0, -1),     int2(0, 0),     int2(0, 1),
        int2(0, 2),     int2(0, 3),     int2(0, 4),     int2(0, 5),     int2(0, 6),     int2(0, 7),     int2(0, 8),     int2(0, 9),     int2(0, 10),     int2(1, -9),
        int2(1, -8),     int2(1, -7),     int2(1, -6),     int2(1, -5),     int2(1, -4),     int2(1, -3),     int2(1, -2),     int2(1, -1),     int2(1, 0),     int2(1, 1),
        int2(1, 2),     int2(1, 3),     int2(1, 4),     int2(1, 5),     int2(1, 6),     int2(1, 7),     int2(1, 8),     int2(1, 9),     int2(2, -9),     int2(2, -8),
        int2(2, -7),     int2(2, -6),     int2(2, -5),     int2(2, -4),     int2(2, -3),     int2(2, -2),     int2(2, -1),     int2(2, 0),     int2(2, 1),     int2(2, 2),
        int2(2, 3),     int2(2, 4),     int2(2, 5),     int2(2, 6),     int2(2, 7),     int2(2, 8),     int2(2, 9),     int2(3, -9),     int2(3, -8),     int2(3, -7),
        int2(3, -6),     int2(3, -5),     int2(3, -4),     int2(3, -3),     int2(3, -2),     int2(3, -1),     int2(3, 0),     int2(3, 1),     int2(3, 2),     int2(3, 3),
        int2(3, 4),     int2(3, 5),     int2(3, 6),     int2(3, 7),     int2(3, 8),     int2(3, 9),     int2(4, -9),     int2(4, -8),     int2(4, -7),     int2(4, -6),
        int2(4, -5),     int2(4, -4),     int2(4, -3),     int2(4, -2),     int2(4, -1),     int2(4, 0),     int2(4, 1),     int2(4, 2),     int2(4, 3),     int2(4, 4),
        int2(4, 5),     int2(4, 6),     int2(4, 7),     int2(4, 8),     int2(4, 9),     int2(5, -8),     int2(5, -7),     int2(5, -6),     int2(5, -5),     int2(5, -4),
        int2(5, -3),     int2(5, -2),     int2(5, -1),     int2(5, 0),     int2(5, 1),     int2(5, 2),     int2(5, 3),     int2(5, 4),     int2(5, 5),     int2(5, 6),
        int2(5, 7),     int2(5, 8),     int2(6, -8),     int2(6, -7),     int2(6, -6),     int2(6, -5),     int2(6, -4),     int2(6, -3),     int2(6, -2),     int2(6, -1),
        int2(6, 0),     int2(6, 1),     int2(6, 2),     int2(6, 3),     int2(6, 4),     int2(6, 5),     int2(6, 6),     int2(6, 7),     int2(6, 8),     int2(7, -7),
        int2(7, -6),     int2(7, -5),     int2(7, -4),     int2(7, -3),     int2(7, -2),     int2(7, -1),     int2(7, 0),     int2(7, 1),     int2(7, 2),     int2(7, 3),
        int2(7, 4),     int2(7, 5),     int2(7, 6),     int2(7, 7),     int2(8, -6),     int2(8, -5),     int2(8, -4),     int2(8, -3),     int2(8, -2),     int2(8, -1),
        int2(8, 0),     int2(8, 1),     int2(8, 2),     int2(8, 3),     int2(8, 4),     int2(8, 5),     int2(8, 6),     int2(9, -4),     int2(9, -3),     int2(9, -2),
        int2(9, -1),     int2(9, 0),     int2(9, 1),     int2(9, 2),     int2(9, 3),     int2(9, 4),     int2(10, 0)
    };

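    // Offsets (relative to the new pixel) of pixels that leave the kernel when it
    // moves one pixel to the right: one per row, just left of that row's leftmost tap.
    const int2 vCircleSamplesNeg[] =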
    {
        int2(-11, 0),     int2(-10, -4),     int2(-10, -3),     int2(-10, -2),     int2(-10, -1),     int2(-10, 1),     int2(-10, 2),     int2(-10, 3),     int2(-10, 4),     int2(-9, -6),
        int2(-9, -5),     int2(-9, 5),     int2(-9, 6),     int2(-8, -7),     int2(-8, 7),     int2(-7, -8),     int2(-7, 8),     int2(-5, -9),     int2(-5, 9),     int2(-1, -10),
        int2(-1, 10)
    };
    const int totalSamplesBorderNeg = 21;

    // Offsets of pixels that enter the kernel on the same move: the rightmost tap of each row.
    const int2 vCircleSamplesPos[] =
    {
        int2(0, -10),     int2(0, 10),     int2(4, -9),     int2(4, 9),     int2(6, -8),     int2(6, 8),     int2(7, -7),     int2(7, 7),     int2(8, -6),     int2(8, -5),
        int2(8, 5),     int2(8, 6),     int2(9, -4),     int2(9, -3),     int2(9, -2),     int2(9, -1),     int2(9, 1),     int2(9, 2),     int2(9, 3),     int2(9, 4),
        int2(10, 0)
    };
    const int totalSamplesBorderPos = 21;
    

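    // Each thread filters a run of 64 horizontally adjacent pixels.
    // First pixel of the run: full convolution over all nTotalSamples taps.
    float4 res = 0;
    int2 coord = int2(DTid.x * 64, DTid.y);
    [loop]
    for (int s = 0; s < nTotalSamples; ++s)
    {
        res += Input[coord + vCircleSamples[s]];
    }
    res /= float(nTotalSamples);
    Result[coord] = res;
    float4 prevRes = res;

    // Remaining 63 pixels: start from the previous pixel's result and fetch
    // only the pixels that leave (Neg) or enter (Pos) the kernel.
    [loop]
    for (int i = 1; i < 64; ++i)
    {
        res = 0;
        coord = int2(DTid.x * 64 + i, DTid.y);
        [loop]
        for (int s = 0; s < totalSamplesBorderNeg; ++s)
        {
            res -= Input[coord + vCircleSamplesNeg[s]];
        }
        [loop]
        for (int s = 0; s < totalSamplesBorderPos; ++s)
        {
            res += Input[coord + vCircleSamplesPos[s]];
        }

        // Normalize the delta and add it to the previous (already normalized) result.
        res /= float(nTotalSamples);
        res += prevRes;
        Result[coord] = res;
        prevRes = res;
    }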
}


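And here is the code for the ground-truth (all samples) bokeh computation. It takes the place of the two loops in CSMain, reuses the same sample table, and runs one thread per pixel: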
float4 res = 0;
int2 coord = int2(DTid.x, DTid.y);
[loop]
for (int s = 0; s < nTotalSamples; ++s)
{
    res += Input[coord + vCircleSamples[s]];
}

res /= float(nTotalSamples);
Result[coord] = res;


Here are the dispatch calls, just in case:
pd3dImmediateContext->Dispatch((width + 63) / 64, (height + 31) / 32, 1); // Proposed method
pd3dImmediateContext->Dispatch(width, (height + 31) / 32, 1); // Ground truth
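The group counts in these calls are ceiling divisions of the image size by the area each thread group covers. A small C++ sketch follows; the helper names are hypothetical. It assumes the `[numthreads( 1, 32, 1 )]` layout from the shader, with 64 pixels per thread for the proposed method and 1 pixel per thread for the ground truth.

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division, as in (width + 63) / 64.
std::uint32_t div_round_up(std::uint32_t n, std::uint32_t d) {
    assert(d > 0);
    return (n + d - 1) / d;
}

struct GroupCounts {
    std::uint32_t x, y;
};

// Proposed method: a 1 x 32 group where each thread covers 64 pixels of its
// row, so one group covers a 64 x 32 tile.
GroupCounts proposed_groups(std::uint32_t width, std::uint32_t height) {
    return {div_round_up(width, 64), div_round_up(height, 32)};
}

// Ground truth: same 1 x 32 group shape, one thread per pixel.
GroupCounts ground_truth_groups(std::uint32_t width, std::uint32_t height) {
    return {width, div_round_up(height, 32)};
}
```

At Full HD (1920x1080), this gives 30 x 34 groups for the proposed method and 1920 x 34 for the ground truth. That is 64 times fewer threads, each doing correspondingly more (but mostly incremental) work.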


Performance for a 21x21 filter at Full HD on a GeForce GTX 1060, with RGBA16_FLOAT source and destination (I didn't do any thorough optimization, e.g., half resolution or more cache-friendly fetches):
Ground truth: 5 ms
Proposed method: 1.7 ms
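As a rough sanity check on these numbers, a short C++ sketch can count texture fetches per output pixel from the shader's constants. The name `amortized_fetches_per_pixel` is hypothetical. The count is 317 taps for the first pixel of each 64-pixel run, then 21 + 21 fetches per step.

```cpp
#include <cassert>

// Average texture fetches per output pixel when one thread filters a run of
// `run` pixels: one full convolution (`full_taps`) followed by `run - 1`
// slide steps of `delta_taps` fetches each (entering + leaving offsets).
double amortized_fetches_per_pixel(int full_taps, int delta_taps, int run) {
    assert(run >= 1);
    return (full_taps + static_cast<double>(run - 1) * delta_taps) / run;
}
```

With `(317, 42, 64)`, this gives 2963 / 64 ≈ 46.3 fetches per pixel, against 317 for the ground truth (`run = 1`), roughly a 6.8x reduction. The measured speedup (5 ms vs 1.7 ms, about 2.9x) is smaller, so fetch count evidently is not the only cost.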

Visual results (apologies for not handling the borders properly; again, I was kinda lazy):

UPD: close-up for smartphone users :)

Original:

Ground Truth Filter:

Proposed Method Filter:

Obviously, there is still a lot of optimization and polishing to do before this is production-ready. Still, I think the concept is original and worth digging into (and I like how it exploits the constant-weight property of the bokeh filter kernel).

Please let me know (in the comments, on Twitter, via private messages, etc.) if you want me to polish and publish a demo. If enough people are interested, I will try to find time for that :)