forked from google/ExoPlayer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MssimCalculatorTest to verify SSIM calculations.
As part of this change, MssimCalculator is moved from androidTest/ to main/. PiperOrigin-RevId: 473771344
- Loading branch information
1 parent
696ef2a
commit 20aa22c
Showing
3 changed files
with
287 additions
and
180 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
191 changes: 191 additions & 0 deletions
191
.../transformer/src/main/java/com/google/android/exoplayer2/transformer/MssimCalculator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
/* | ||
* Copyright 2022 The Android Open Source Project | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.google.android.exoplayer2.transformer; | ||
|
||
import static java.lang.Math.pow; | ||
|
||
/**
 * Image comparison tool that calculates the Mean Structural Similarity (MSSIM) of two images,
 * developed by Wang, Bovik, Sheikh, and Simoncelli.
 *
 * <p>MSSIM divides the image into non-overlapping windows of up to {@code WINDOW_SIZE} x {@code
 * WINDOW_SIZE} pixels, calculates the SSIM of each window, then returns the average. Windows on
 * the right and bottom edges are shrunk to fit when the image dimensions are not multiples of the
 * window size.
 *
 * @see <a href="https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf">The SSIM paper</a>.
 */
/* package */ final class MssimCalculator {
  // Referred to as 'L' in the SSIM paper, this constant defines the maximum pixel values. The
  // range of pixel values is 0 to 255 (8 bit unsigned range).
  private static final int PIXEL_MAX_VALUE = 255;

  // K1 and K2, as defined in the SSIM paper.
  private static final double K1 = 0.01;
  private static final double K2 = 0.03;

  // C1 and C2 stabilize the SSIM value when either (referenceMean^2 + distortedMean^2) or
  // (referenceVariance + distortedVariance) is close to 0. See the SSIM formula in
  // `getWindowSsim` for how these values impact each other in the calculation.
  private static final double C1 = (PIXEL_MAX_VALUE * K1) * (PIXEL_MAX_VALUE * K1);
  private static final double C2 = (PIXEL_MAX_VALUE * K2) * (PIXEL_MAX_VALUE * K2);

  // Maximum side length, in pixels, of a window.
  private static final int WINDOW_SIZE = 8;

  private MssimCalculator() {}

  /**
   * Calculates the Mean Structural Similarity (MSSIM) between two images.
   *
   * <p>Both buffers are read row by row with a stride equal to {@code width}, one byte per pixel,
   * each byte being interpreted as an unsigned value. Each buffer must therefore hold at least
   * {@code width * height} bytes.
   *
   * @param referenceBuffer The luma channel (Y) buffer of the reference image.
   * @param distortedBuffer The luma channel (Y) buffer of the distorted image.
   * @param width The image width in pixels.
   * @param height The image height in pixels.
   * @return The MSSIM score between the input images, or {@code 1.0} if the dimensions yield no
   *     window at all (that is, {@code width} or {@code height} is not positive).
   */
  public static double calculate(
      byte[] referenceBuffer, byte[] distortedBuffer, int width, int height) {
    double totalSsim = 0;
    int windowsCount = 0;

    for (int currentWindowY = 0; currentWindowY < height; currentWindowY += WINDOW_SIZE) {
      int windowHeight = computeWindowSize(currentWindowY, height);
      for (int currentWindowX = 0; currentWindowX < width; currentWindowX += WINDOW_SIZE) {
        windowsCount++;
        int windowWidth = computeWindowSize(currentWindowX, width);
        // Index of the window's top-left pixel inside the buffers.
        int bufferIndexOffset =
            get1dIndex(currentWindowX, currentWindowY, /* stride= */ width, /* offset= */ 0);
        double referenceMean =
            getMean(
                referenceBuffer, bufferIndexOffset, /* stride= */ width, windowWidth, windowHeight);
        double distortedMean =
            getMean(
                distortedBuffer, bufferIndexOffset, /* stride= */ width, windowWidth, windowHeight);

        double[] variances =
            getVariancesAndCovariance(
                referenceBuffer,
                distortedBuffer,
                referenceMean,
                distortedMean,
                bufferIndexOffset,
                /* stride= */ width,
                windowWidth,
                windowHeight);
        double referenceVariance = variances[0];
        double distortedVariance = variances[1];
        double referenceDistortedCovariance = variances[2];

        totalSsim +=
            getWindowSsim(
                referenceMean,
                distortedMean,
                referenceVariance,
                distortedVariance,
                referenceDistortedCovariance);
      }
    }

    if (windowsCount == 0) {
      // No pixels to compare: report the images as identical.
      return 1.0d;
    }

    return totalSsim / windowsCount;
  }

  /**
   * Returns the window size at the provided start coordinate, uses {@link #WINDOW_SIZE} if there is
   * enough space, otherwise the number of pixels between {@code start} and {@code dimension}.
   */
  private static int computeWindowSize(int start, int dimension) {
    if (start + WINDOW_SIZE <= dimension) {
      return WINDOW_SIZE;
    }
    return dimension - start;
  }

  /** Returns the SSIM of a window. */
  private static double getWindowSsim(
      double referenceMean,
      double distortedMean,
      double referenceVariance,
      double distortedVariance,
      double referenceDistortedCovariance) {

    // Uses equation 13 on page 6 from the linked paper.
    double numerator =
        (((2 * referenceMean * distortedMean) + C1) * ((2 * referenceDistortedCovariance) + C2));
    double denominator =
        ((referenceMean * referenceMean) + (distortedMean * distortedMean) + C1)
            * (referenceVariance + distortedVariance + C2);
    return numerator / denominator;
  }

  /** Returns the mean of the pixels in the window. */
  private static double getMean(
      byte[] pixelBuffer, int bufferIndexOffset, int stride, int windowWidth, int windowHeight) {
    double total = 0;
    for (int y = 0; y < windowHeight; y++) {
      for (int x = 0; x < windowWidth; x++) {
        // Mask with 0xFF to read the byte as an unsigned value in [0, 255].
        total += pixelBuffer[get1dIndex(x, y, stride, bufferIndexOffset)] & 0xFF;
      }
    }
    return total / (windowWidth * windowHeight);
  }

  /**
   * Calculates the variances and covariance of the pixels in the window for both buffers.
   *
   * <p>The sums of squared deviations are normalized by {@code (pixelCount - 1)}. A window holding
   * a single pixel has no spread, so all three values are reported as 0 for it rather than
   * evaluating {@code 0.0 / 0}, which would be NaN.
   *
   * @return An array holding, in order, the reference variance, the distorted variance and the
   *     covariance between the reference and distorted windows.
   */
  private static double[] getVariancesAndCovariance(
      byte[] referenceBuffer,
      byte[] distortedBuffer,
      double referenceMean,
      double distortedMean,
      int bufferIndexOffset,
      int stride,
      int windowWidth,
      int windowHeight) {
    int normalizationFactor = windowWidth * windowHeight - 1;
    if (normalizationFactor <= 0) {
      // Single-pixel window (occurs at the image corner when both dimensions are congruent to 1
      // modulo WINDOW_SIZE). Dividing by zero here would poison the whole MSSIM score with NaN.
      return new double[] {0, 0, 0};
    }

    double referenceVariance = 0;
    double distortedVariance = 0;
    double referenceDistortedCovariance = 0;
    for (int y = 0; y < windowHeight; y++) {
      for (int x = 0; x < windowWidth; x++) {
        int index = get1dIndex(x, y, stride, bufferIndexOffset);
        double referencePixelDeviation = (referenceBuffer[index] & 0xFF) - referenceMean;
        double distortedPixelDeviation = (distortedBuffer[index] & 0xFF) - distortedMean;
        referenceVariance += referencePixelDeviation * referencePixelDeviation;
        distortedVariance += distortedPixelDeviation * distortedPixelDeviation;
        referenceDistortedCovariance += referencePixelDeviation * distortedPixelDeviation;
      }
    }

    return new double[] {
      referenceVariance / normalizationFactor,
      distortedVariance / normalizationFactor,
      referenceDistortedCovariance / normalizationFactor
    };
  }

  /**
   * Translates a 2D coordinate into an 1D index, based on the stride of the 2D space.
   *
   * @param x The width component of coordinate.
   * @param y The height component of coordinate.
   * @param stride The width of the 2D space.
   * @param offset An offset to apply.
   * @return The 1D index.
   */
  private static int get1dIndex(int x, int y, int stride, int offset) {
    return x + (y * stride) + offset;
  }
}
Oops, something went wrong.