diff --git a/CHANGELOG.md b/CHANGELOG.md
index 92e765fe22..e5b17e8d51 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Improve default settings of EfficientAD
+
 ### Deprecated
 
 ### Fixed
diff --git a/README.md b/README.md
index d5ec52a47f..6147febecc 100644
--- a/README.md
+++ b/README.md
@@ -107,6 +107,7 @@ where the currently available models are:
 - [DFKDE](src/anomalib/models/dfkde)
 - [DFM](src/anomalib/models/dfm)
 - [DRAEM](src/anomalib/models/draem)
+- [EfficientAD](src/anomalib/models/efficientad)
 - [FastFlow](src/anomalib/models/fastflow)
 - [GANomaly](src/anomalib/models/ganomaly)
 - [PADIM](src/anomalib/models/padim)
@@ -285,36 +286,40 @@ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
 
 ## Image-Level AUC
 
-| Model | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
-| ------------- | ------------------ | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-----: | :-------: | :-------: | :------: | :-------: | :-------: | :-------: | :--------: | :--------: | :-------: |
-| **PatchCore** | **Wide ResNet-50** | **0.980** | 0.984 | 0.959 | 1.000 | **1.000** | 0.989 | 1.000 | **0.990** | **0.982** | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | **1.000** | 0.982 |
-| PatchCore | ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | **0.943** | 0.931 | 0.996 | 0.953 |
-| CFlow | Wide ResNet-50 | 0.962 | 0.986 | 0.962 | **1.000** | 0.999 | 0.993 | **1.0** | 0.893 | 0.945 | **1.0** | **0.995** | 0.924 | 0.908 | 0.897 | 0.943 | **0.984** |
-| CFA | Wide ResNet-50 | 0.956 | 0.978 | 0.961 | 0.990 | 0.999 | 0.994 | 0.998 | 0.979 | 0.872 | 1.000 | **0.995** | **0.946** | 0.703 | **1.000** | 0.957 | 0.967 |
-| CFA | ResNet-18 | 0.930 | 0.953 | 0.947 | 0.999 | 1.000 | **1.000** | 0.991 | 0.947 | 0.858 | 0.995 | 0.932 | 0.887 | 0.625 | 0.994 | 0.895 | 0.919 |
-| PaDiM | Wide ResNet-50 | 0.950 | **0.995** | 0.942 | **1.000** | 0.974 | 0.993 | 0.999 | 0.878 | 0.927 | 0.964 | 0.989 | 0.939 | 0.845 | 0.942 | 0.976 | 0.882 |
-| PaDiM | ResNet-18 | 0.891 | 0.945 | 0.857 | 0.982 | 0.950 | 0.976 | 0.994 | 0.844 | 0.901 | 0.750 | 0.961 | 0.863 | 0.759 | 0.889 | 0.920 | 0.780 |
-| DFM | Wide ResNet-50 | 0.943 | 0.855 | 0.784 | 0.997 | 0.995 | 0.975 | 0.999 | 0.969 | 0.924 | 0.978 | 0.939 | 0.962 | 0.873 | 0.969 | 0.971 | 0.961 |
-| DFM | ResNet-18 | 0.936 | 0.817 | 0.736 | 0.993 | 0.966 | 0.977 | 1.000 | 0.956 | 0.944 | 0.994 | 0.922 | 0.961 | 0.89 | 0.969 | 0.939 | 0.969 |
-| STFPM | Wide ResNet-50 | 0.876 | 0.957 | 0.977 | 0.981 | 0.976 | 0.939 | 0.987 | 0.878 | 0.732 | 0.995 | 0.973 | 0.652 | 0.825 | 0.500 | 0.875 | 0.899 |
-| STFPM | ResNet-18 | 0.893 | 0.954 | **0.982** | 0.989 | 0.949 | 0.961 | 0.979 | 0.838 | 0.759 | 0.999 | 0.956 | 0.705 | 0.835 | **0.997** | 0.853 | 0.645 |
-| DFKDE | Wide ResNet-50 | 0.774 | 0.708 | 0.422 | 0.905 | 0.959 | 0.903 | 0.936 | 0.746 | 0.853 | 0.736 | 0.687 | 0.749 | 0.574 | 0.697 | 0.843 | 0.892 |
-| DFKDE | ResNet-18 | 0.762 | 0.646 | 0.577 | 0.669 | 0.965 | 0.863 | 0.951 | 0.751 | 0.698 | 0.806 | 0.729 | 0.607 | 0.694 | 0.767 | 0.839 | 0.866 |
-| GANomaly | | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 |
+| Model | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
+| --------------- | -------------- | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :------: | :-------: | :-------: | :-------: | :--------: | :--------: | :-------: |
+| **EfficientAD** | **PDN-S** | **0.982** | 0.982 | **1.000** | 0.997 | **1.000** | 0.986 | **1.000** | 0.952 | 0.950 | 0.952 | 0.979 | **0.987** | 0.960 | 0.997 | 0.999 | **0.994** |
+| EfficientAD | PDN-M | 0.975 | 0.972 | 0.998 | **1.000** | 0.999 | 0.984 | 0.991 | 0.945 | 0.957 | 0.948 | 0.989 | 0.926 | **0.975** | **1.000** | 0.965 | 0.971 |
+| PatchCore | Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | **1.000** | 0.989 | 1.000 | **0.990** | **0.982** | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | **1.000** | 0.982 |
+| PatchCore | ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 |
+| CFlow | Wide ResNet-50 | 0.962 | 0.986 | 0.962 | **1.000** | 0.999 | 0.993 | **1.0** | 0.893 | 0.945 | **1.0** | **0.995** | 0.924 | 0.908 | 0.897 | 0.943 | 0.984 |
+| CFA | Wide ResNet-50 | 0.956 | 0.978 | 0.961 | 0.990 | 0.999 | 0.994 | 0.998 | 0.979 | 0.872 | 1.000 | **0.995** | 0.946 | 0.703 | **1.000** | 0.957 | 0.967 |
+| CFA | ResNet-18 | 0.930 | 0.953 | 0.947 | 0.999 | 1.000 | **1.000** | 0.991 | 0.947 | 0.858 | 0.995 | 0.932 | 0.887 | 0.625 | 0.994 | 0.895 | 0.919 |
+| PaDiM | Wide ResNet-50 | 0.950 | **0.995** | 0.942 | **1.000** | 0.974 | 0.993 | 0.999 | 0.878 | 0.927 | 0.964 | 0.989 | 0.939 | 0.845 | 0.942 | 0.976 | 0.882 |
+| PaDiM | ResNet-18 | 0.891 | 0.945 | 0.857 | 0.982 | 0.950 | 0.976 | 0.994 | 0.844 | 0.901 | 0.750 | 0.961 | 0.863 | 0.759 | 0.889 | 0.920 | 0.780 |
+| DFM | Wide ResNet-50 | 0.943 | 0.855 | 0.784 | 0.997 | 0.995 | 0.975 | 0.999 | 0.969 | 0.924 | 0.978 | 0.939 | 0.962 | 0.873 | 0.969 | 0.971 | 0.961 |
+| DFM | ResNet-18 | 0.936 | 0.817 | 0.736 | 0.993 | 0.966 | 0.977 | 1.000 | 0.956 | 0.944 | 0.994 | 0.922 | 0.961 | 0.89 | 0.969 | 0.939 | 0.969 |
+| STFPM | Wide ResNet-50 | 0.876 | 0.957 | 0.977 | 0.981 | 0.976 | 0.939 | 0.987 | 0.878 | 0.732 | 0.995 | 0.973 | 0.652 | 0.825 | 0.500 | 0.875 | 0.899 |
+| STFPM | ResNet-18 | 0.893 | 0.954 | **0.982** | 0.989 | 0.949 | 0.961 | 0.979 | 0.838 | 0.759 | 0.999 | 0.956 | 0.705 | 0.835 | **0.997** | 0.853 | 0.645 |
+| DFKDE | Wide ResNet-50 | 0.774 | 0.708 | 0.422 | 0.905 | 0.959 | 0.903 | 0.936 | 0.746 | 0.853 | 0.736 | 0.687 | 0.749 | 0.574 | 0.697 | 0.843 | 0.892 |
+| DFKDE | ResNet-18 | 0.762 | 0.646 | 0.577 | 0.669 | 0.965 | 0.863 | 0.951 | 0.751 | 0.698 | 0.806 | 0.729 | 0.607 | 0.694 | 0.767 | 0.839 | 0.866 |
+| GANomaly | | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 |
 
 ## Pixel-Level AUC
 
-| Model | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
-| --------- | ------------------ | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :--------: | :--------: | :-------: |
-| **CFA** | **Wide ResNet-50** | **0.983** | 0.980 | 0.954 | 0.989 | **0.985** | **0.974** | **0.989** | **0.988** | **0.989** | 0.985 | **0.992** | **0.988** | 0.979 | **0.991** | 0.977 | **0.990** |
-| CFA | ResNet-18 | 0.979 | 0.970 | 0.973 | 0.992 | 0.978 | 0.964 | 0.986 | 0.984 | 0.987 | 0.987 | 0.981 | 0.981 | 0.973 | 0.990 | 0.964 | 0.978 |
-| PatchCore | Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | **0.988** | 0.988 | 0.987 | 0.989 | 0.980 | **0.989** | 0.988 | **0.981** | 0.983 |
-| PatchCore | ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |
-| CFlow | Wide ResNet-50 | 0.971 | 0.986 | 0.968 | 0.993 | 0.968 | 0.924 | 0.981 | 0.955 | 0.988 | **0.990** | 0.982 | 0.983 | 0.979 | 0.985 | 0.897 | 0.980 |
-| PaDiM | Wide ResNet-50 | 0.979 | **0.991** | 0.970 | 0.993 | 0.955 | 0.957 | 0.985 | 0.970 | 0.988 | 0.985 | 0.982 | 0.966 | 0.988 | **0.991** | 0.976 | 0.986 |
-| PaDiM | ResNet-18 | 0.968 | 0.984 | 0.918 | **0.994** | 0.934 | 0.947 | 0.983 | 0.965 | 0.984 | 0.978 | 0.970 | 0.957 | 0.978 | 0.988 | 0.968 | 0.979 |
-| STFPM | Wide ResNet-50 | 0.903 | 0.987 | **0.989** | 0.980 | 0.966 | 0.956 | 0.966 | 0.913 | 0.956 | 0.974 | 0.961 | 0.946 | 0.988 | 0.178 | 0.807 | 0.980 |
-| STFPM | ResNet-18 | 0.951 | 0.986 | 0.988 | 0.991 | 0.946 | 0.949 | 0.971 | 0.898 | 0.962 | 0.981 | 0.942 | 0.878 | 0.983 | 0.983 | 0.838 | 0.972 |
+| Model | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
+| ----------- | ------------------ | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :--------: | :--------: | :-------: |
+| **CFA** | **Wide ResNet-50** | **0.983** | 0.980 | 0.954 | 0.989 | **0.985** | **0.974** | **0.989** | **0.988** | **0.989** | 0.985 | **0.992** | **0.988** | 0.979 | **0.991** | 0.977 | **0.990** |
+| CFA | ResNet-18 | 0.979 | 0.970 | 0.973 | 0.992 | 0.978 | 0.964 | 0.986 | 0.984 | 0.987 | 0.987 | 0.981 | 0.981 | 0.973 | 0.990 | 0.964 | 0.978 |
+| PatchCore | Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | **0.988** | 0.988 | 0.987 | 0.989 | 0.980 | **0.989** | 0.988 | **0.981** | 0.983 |
+| PatchCore | ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |
+| CFlow | Wide ResNet-50 | 0.971 | 0.986 | 0.968 | 0.993 | 0.968 | 0.924 | 0.981 | 0.955 | 0.988 | **0.990** | 0.982 | 0.983 | 0.979 | 0.985 | 0.897 | 0.980 |
+| PaDiM | Wide ResNet-50 | 0.979 | **0.991** | 0.970 | 0.993 | 0.955 | 0.957 | 0.985 | 0.970 | 0.988 | 0.985 | 0.982 | 0.966 | 0.988 | **0.991** | 0.976 | 0.986 |
+| PaDiM | ResNet-18 | 0.968 | 0.984 | 0.918 | **0.994** | 0.934 | 0.947 | 0.983 | 0.965 | 0.984 | 0.978 | 0.970 | 0.957 | 0.978 | 0.988 | 0.968 | 0.979 |
+| EfficientAD | PDN-S | 0.960 | 0.963 | 0.937 | 0.976 | 0.907 | 0.868 | 0.983 | 0.983 | 0.980 | 0.976 | 0.978 | 0.986 | 0.985 | 0.962 | 0.956 | 0.961 |
+| EfficientAD | PDN-M | 0.957 | 0.948 | 0.937 | 0.976 | 0.906 | 0.867 | 0.976 | 0.986 | 0.957 | 0.977 | 0.984 | 0.978 | 0.986 | 0.964 | 0.947 | 0.960 |
+| STFPM | Wide ResNet-50 | 0.903 | 0.987 | **0.989** | 0.980 | 0.966 | 0.956 | 0.966 | 0.913 | 0.956 | 0.974 | 0.961 | 0.946 | 0.988 | 0.178 | 0.807 | 0.980 |
+| STFPM | ResNet-18 | 0.951 | 0.986 | 0.988 | 0.991 | 0.946 | 0.949 | 0.971 | 0.898 | 0.962 | 0.981 | 0.942 | 0.878 | 0.983 | 0.983 | 0.838 | 0.972 |
 
 ## Image F1 Score
 
@@ -322,6 +327,8 @@
 | ------------- | ------------------ | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :--------: | :--------: | :-------: |
 | **PatchCore** | **Wide ResNet-50** | **0.976** | 0.971 | 0.974 | **1.000** | **1.000** | 0.967 | **1.000** | 0.968 | **0.982** | **1.000** | 0.984 | 0.940 | 0.943 | 0.938 | **1.000** | **0.979** |
 | PatchCore | ResNet-18 | 0.970 | 0.949 | 0.946 | **1.000** | 0.98 | 0.992 | **1.000** | **0.978** | 0.969 | **1.000** | **0.989** | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |
+| EfficientAD | PDN-S | 0.970 | 0.966 | **1.000** | 0.995 | **1.000** | 0.975 | **1.000** | 0.907 | 0.956 | 0.897 | 0.978 | 0.982 | 0.944 | 0.984 | 0.988 | 0.983 |
+| EfficientAD | PDN-M | 0.966 | 0.977 | 0.991 | **1.000** | 0.994 | 0.967 | 0.984 | 0.922 | 0.969 | 0.884 | 0.984 | 0.952 | 0.955 | 1.000 | 0.929 | 0.979 |
 | CFA | Wide ResNet-50 | 0.962 | 0.961 | 0.957 | 0.995 | 0.994 | 0.983 | 0.984 | 0.962 | 0.946 | **1.000** | 0.984 | **0.952** | 0.855 | **1.000** | 0.907 | 0.975 |
 | CFA | ResNet-18 | 0.946 | 0.956 | 0.946 | 0.973 | **1.000** | **1.000** | 0.983 | 0.907 | 0.938 | 0.996 | 0.958 | 0.920 | 0.858 | 0.984 | 0.795 | 0.949 |
 | CFlow | Wide ResNet-50 | 0.944 | 0.972 | 0.932 | **1.000** | 0.988 | 0.967 | **1.000** | 0.832 | 0.939 | **1.000** | 0.979 | 0.924 | **0.971** | 0.870 | 0.818 | 0.967 |
diff --git a/src/anomalib/models/efficientad/README.md b/src/anomalib/models/efficientad/README.md
index 4d688dfd2d..d93afddd22 100644
--- a/src/anomalib/models/efficientad/README.md
+++ b/src/anomalib/models/efficientad/README.md
@@ -14,7 +14,7 @@ Features are extracted from a pre-trained teacher model and used to train a student model.
 
 ### Anomaly Detection
 
-Anomalies are detected as the difference in output feature maps between the student model and the autoencoder model.
+Anomalies are detected as the difference in output feature maps between the teacher model, the student model and the autoencoder model.
 
 ## Usage
 
@@ -28,12 +28,14 @@ All results gathered with seed `42`.
 
 ### Image-Level AUC
 
-| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
-| ------------------------ | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
-| Distilled Teacher Medium | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 |
+| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
+| ------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
+| EfficientAD-S | 0.982 | 0.982 | 1.000 | 0.997 | 1.000 | 0.986 | 1.000 | 0.952 | 0.950 | 0.952 | 0.979 | 0.987 | 0.960 | 0.997 | 0.999 | 0.994 |
+| EfficientAD-M | 0.975 | 0.972 | 0.998 | 1.000 | 0.999 | 0.984 | 0.991 | 0.945 | 0.957 | 0.948 | 0.989 | 0.926 | 0.975 | 1.000 | 0.965 | 0.971 |
 
 ### Image F1 Score
 
-| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
-| ------------------------ | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
-| Distilled Teacher Medium | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 |
+| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
+| ------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
+| EfficientAD-S | 0.970 | 0.966 | 1.000 | 0.995 | 1.000 | 0.975 | 1.000 | 0.907 | 0.956 | 0.897 | 0.978 | 0.982 | 0.944 | 0.984 | 0.988 | 0.983 |
+| EfficientAD-M | 0.966 | 0.977 | 0.991 | 1.000 | 0.994 | 0.967 | 0.984 | 0.922 | 0.969 | 0.884 | 0.984 | 0.952 | 0.955 | 1.000 | 0.929 | 0.979 |
diff --git a/src/anomalib/models/efficientad/config.yaml b/src/anomalib/models/efficientad/config.yaml
index dc654b4cd1..706591fb36 100644
--- a/src/anomalib/models/efficientad/config.yaml
+++ b/src/anomalib/models/efficientad/config.yaml
@@ -21,10 +21,11 @@ dataset:
 model:
   name: efficientad
   teacher_out_channels: 384
-  model_size: medium # options: [small, medium]
+  model_size: small # options: [small, medium]
   lr: 0.0001
   weight_decay: 0.00001
-  padding: true
+  padding: false
+  pad_maps: true # relevant for "padding: false", see EfficientAD in lightning_model.py
 
   # generic params
   normalization_method: min_max # options: [null, min_max, cdf]
@@ -73,7 +74,7 @@ trainer:
   accumulate_grad_batches: 1
   max_epochs: 200
   min_epochs: null
-  max_steps: -1
+  max_steps: 70000
   min_steps: null
   max_time: null
   limit_train_batches: 1.0
diff --git a/src/anomalib/models/efficientad/lightning_model.py b/src/anomalib/models/efficientad/lightning_model.py
index f207eaa366..382595b96b 100644
--- a/src/anomalib/models/efficientad/lightning_model.py
+++ b/src/anomalib/models/efficientad/lightning_model.py
@@ -60,6 +60,8 @@ class EfficientAD(AnomalyModule):
         lr (float): learning rate
         weight_decay (float): optimizer weight decay
         padding (bool): use padding in convoluional layers
+        pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
+            output anomaly maps so that their size matches the size in the padding = True case.
         batch_size (int): batch size for imagenet dataloader
     """
 
@@ -67,10 +69,11 @@ def __init__(
         self,
         teacher_out_channels: int,
         image_size: tuple[int, int],
-        model_size: EfficientADModelSize = EfficientADModelSize.M,
+        model_size: EfficientADModelSize = EfficientADModelSize.S,
         lr: float = 0.0001,
         weight_decay: float = 0.00001,
         padding: bool = False,
+        pad_maps: bool = True,
         batch_size: int = 1,
     ) -> None:
         super().__init__()
@@ -81,6 +84,7 @@ def __init__(
             input_size=image_size,
             model_size=model_size,
             padding=padding,
+            pad_maps=pad_maps,
         )
         self.batch_size = batch_size
         self.image_size = image_size
@@ -120,7 +124,7 @@ def prepare_imagenette_data(self) -> None:
 
     @torch.no_grad()
     def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, Tensor]:
-        """Calculate the the mean and std of the teacher models activations.
+        """Calculate the mean and std of the teacher models activations.
 
         Args:
             dataloader (DataLoader): Dataloader of the respective dataset.
@@ -170,16 +174,43 @@ def map_norm_quantiles(self, dataloader: DataLoader) -> dict[str, Tensor]:
             map_ae = output["map_ae"]
             maps_st.append(map_st)
             maps_ae.append(map_ae)
-        maps_st = torch.cat(maps_st)
-        maps_ae = torch.cat(maps_ae)
-        qa_st = torch.quantile(maps_st, q=0.9).to(self.device)
-        qb_st = torch.quantile(maps_st, q=0.995).to(self.device)
-        qa_ae = torch.quantile(maps_ae, q=0.9).to(self.device)
-        qb_ae = torch.quantile(maps_ae, q=0.995).to(self.device)
+
+        qa_st, qb_st = self._get_quantiles_of_maps(maps_st)
+        qa_ae, qb_ae = self._get_quantiles_of_maps(maps_ae)
 
         return {"qa_st": qa_st, "qa_ae": qa_ae, "qb_st": qb_st, "qb_ae": qb_ae}
 
+    def _get_quantiles_of_maps(self, maps: list[Tensor]) -> tuple[Tensor, Tensor]:
+        """Calculate 90% and 99.5% quantiles of the given anomaly maps.
+
+        If the total number of elements in the given maps is larger than 16777216
+        the returned quantiles are computed on a random subset of the given
+        elements.
+
+        Args:
+            maps (list[Tensor]): List of anomaly maps.
+
+        Returns:
+            tuple[Tensor, Tensor]: Two scalars - the 90% and the 99.5% quantile.
+        """
+        maps_flat = torch.flatten(torch.cat(maps))
+        # torch.quantile only works with input size up to 2**24 elements, see
+        # https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291
+        # if we have more elements we need to decrease the size
+        # we do this by sampling random elements of maps_flat because then
+        # the locations of the quantiles (90% and 99.5%) will still be
+        # valid even though they might not be the exact quantiles.
+        max_input_size = 2**24
+        if len(maps_flat) > max_input_size:
+            # select a random subset with max_input_size elements.
+            perm = torch.randperm(len(maps_flat), device=self.device)
+            idx = perm[:max_input_size]
+            maps_flat = maps_flat[idx]
+        qa = torch.quantile(maps_flat, q=0.9).to(self.device)
+        qb = torch.quantile(maps_flat, q=0.995).to(self.device)
+        return qa, qb
+
     def configure_optimizers(self) -> optim.Optimizer:
-        optimizer = optim.AdamW(
+        optimizer = optim.Adam(
             list(self.model.student.parameters()) + list(self.model.ae.parameters()),
             lr=self.lr,
             weight_decay=self.weight_decay,
         )
@@ -197,7 +228,7 @@ def on_train_start(self) -> None:
         self.model.mean_std.update(channel_mean_std)
 
     def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> dict[str, Tensor]:
-        """Training step for EfficintAD returns the student, autoencoder and combined loss.
+        """Training step for EfficientAD returns the student, autoencoder and combined loss.
 
         Args:
             batch (batch: dict[str, str | Tensor]): Batch containing image filename, image, label and mask
@@ -228,7 +259,7 @@ def on_validation_start(self) -> None:
         Calculate the feature map quantiles of the validation dataset and push to the model.
         """
         if (self.current_epoch + 1) == self.trainer.max_epochs:
-            map_norm_quantiles = self.map_norm_quantiles(self.trainer.datamodule.train_dataloader())
+            map_norm_quantiles = self.map_norm_quantiles(self.trainer.datamodule.val_dataloader())
             self.model.quantiles.update(map_norm_quantiles)
 
     def validation_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT:
@@ -261,6 +292,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
             lr=hparams.model.lr,
             weight_decay=hparams.model.weight_decay,
             padding=hparams.model.padding,
+            pad_maps=hparams.model.pad_maps,
             image_size=hparams.dataset.image_size,
             batch_size=hparams.dataset.train_batch_size,
         )
diff --git a/src/anomalib/models/efficientad/torch_model.py b/src/anomalib/models/efficientad/torch_model.py
index 2b09e3644f..6b42275e23 100644
--- a/src/anomalib/models/efficientad/torch_model.py
+++ b/src/anomalib/models/efficientad/torch_model.py
@@ -193,6 +193,8 @@ class EfficientADModel(nn.Module):
         input_size (tuple): size of input images
         model_size (str): size of student and teacher model
         padding (bool): use padding in convoluional layers
+        pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
+            output anomaly maps so that their size matches the size in the padding = True case.
         device (str): which device the model should be loaded on
     """
 
@@ -200,11 +202,13 @@ def __init__(
         self,
         teacher_out_channels: int,
         input_size: tuple[int, int],
-        model_size: EfficientADModelSize = EfficientADModelSize.M,
+        model_size: EfficientADModelSize = EfficientADModelSize.S,
         padding=False,
+        pad_maps=True,
     ) -> None:
         super().__init__()
+        self.pad_maps = pad_maps
         self.teacher: PDN_M | PDN_S
         self.student: PDN_M | PDN_S
@@ -308,6 +312,9 @@ def forward(self, batch: Tensor, batch_imagenet: Tensor = None) -> Tensor | dict
             (ae_output - student_output[:, self.teacher_out_channels :]) ** 2, dim=1, keepdim=True
         )
+        if self.pad_maps:
+            map_st = F.pad(map_st, (4, 4, 4, 4))
+            map_stae = F.pad(map_stae, (4, 4, 4, 4))
         map_st = F.interpolate(map_st, size=(self.input_size[0], self.input_size[1]), mode="bilinear")
         map_stae = F.interpolate(map_stae, size=(self.input_size[0], self.input_size[1]), mode="bilinear")
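The updated efficientad README describes detection as the discrepancy between teacher, student and autoencoder feature maps. A toy sketch of that idea follows, using made-up tensor shapes and a simple 50/50 combination of the two maps rather than the real `EfficientADModel` networks; only the slicing of the student output mirrors the `forward` shown above.

```python
import torch

# Toy sketch: the student predicts both the teacher features and the
# autoencoder features, so its output has twice the channel count.
b, c, h, w = 2, 384, 56, 56          # assumed batch/channel/map sizes
teacher_out = torch.rand(b, c, h, w)
student_out = torch.rand(b, 2 * c, h, w)
ae_out = torch.rand(b, c, h, w)

# Local anomaly map: teacher-student discrepancy.
map_st = torch.mean((teacher_out - student_out[:, :c]) ** 2, dim=1, keepdim=True)
# Global anomaly map: autoencoder-student discrepancy.
map_stae = torch.mean((ae_out - student_out[:, c:]) ** 2, dim=1, keepdim=True)
# Combined map (before the quantile-based normalization applied at validation time).
anomaly_map = 0.5 * map_st + 0.5 * map_stae
```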
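On the new `pad_maps` option: with `padding: false` the PDN convolutions shrink the feature maps, so the raw anomaly maps come out smaller than with `padding: true`. A minimal sketch, assuming a 256x256 input and a 56x56 unpadded map, of how padding 4 pixels per side before the bilinear upscale restores the same output size as the padded configuration:

```python
import torch
import torch.nn.functional as F

image_size = (256, 256)              # assumed input resolution
map_st = torch.rand(1, 1, 56, 56)    # assumed unpadded anomaly map size

# Compensate for the border pixels lost by the unpadded convolutions.
map_st = F.pad(map_st, (4, 4, 4, 4))                       # -> 1 x 1 x 64 x 64
map_st = F.interpolate(map_st, size=image_size, mode="bilinear")
print(map_st.shape)                                        # torch.Size([1, 1, 256, 256])
```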
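The `_get_quantiles_of_maps` helper works around `torch.quantile` rejecting inputs with more than 2**24 elements by first sampling a random subset; because the subset is drawn uniformly, the 90% and 99.5% quantile estimates stay close to the exact values. A self-contained approximation of the same idea (standalone function, not the lightning module method):

```python
import torch

def approx_quantiles(maps: list[torch.Tensor], max_input_size: int = 2**24) -> tuple[torch.Tensor, torch.Tensor]:
    """Return approximate 90% and 99.5% quantiles of the concatenated maps."""
    maps_flat = torch.flatten(torch.cat(maps))
    if maps_flat.numel() > max_input_size:
        # Random subset keeps torch.quantile within its supported input size.
        idx = torch.randperm(maps_flat.numel(), device=maps_flat.device)[:max_input_size]
        maps_flat = maps_flat[idx]
    return torch.quantile(maps_flat, q=0.9), torch.quantile(maps_flat, q=0.995)

# Example usage with two batches of 56x56 maps.
maps = [torch.rand(8, 1, 56, 56), torch.rand(8, 1, 56, 56)]
qa, qb = approx_quantiles(maps)
```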