From 1276e272a136f9ba5c040def0d42a9a28018f09e Mon Sep 17 00:00:00 2001 From: Om Doiphode Date: Sun, 18 Aug 2024 10:17:55 +0530 Subject: [PATCH] Removed deepforest_docs folder Use models from Huggingface Clean up the tests Update style for tests Switch to opencv-python-headless Using opencv-python was causing installation conflicts with albumentations which relies on the headless version. We were only using the GUI aspects of opencv in a single, debugging only, function. So this switches that function to matplotlib and updates the dependencies. Fixes #464. Don't install opencv via apt for testing We want to test the Python installs so we shouldn't be doing an apt install. This is particularly true since we've switched to opencv-python-headless as the dependency. --- .github/workflows/Conda-app.yml | 1 - deepforest/utilities.py | 160 ++++------- deepforest/visualize.py | 4 +- dev_requirements.txt | 3 +- .../deepforest_docs => }/CONTRIBUTING.md | 0 docs/Makefile | 2 +- docs/{source => }/_templates/navbar2.html | 0 .../advanced_features/CropModels.md | 0 .../advanced_features/ExtendingModule.md | 0 .../advanced_features/Model_Architecture.md | 0 .../advanced_features/index.rst | 0 .../advanced_features/multi_species.md | 0 .../advanced_features/scaling.md | 0 .../advanced_features/visualizations.md | 0 docs/{source => }/conf.py | 8 +- .../data_annotation/annotation.md | 4 +- .../data_annotation/index.rst | 0 .../deepforest_docs => }/deepforestr.md | 0 docs/developer_resources/authors.rst | 1 + .../developer_resources/code_of_conduct.rst | 0 docs/developer_resources/history.rst | 1 + .../developer_resources/index.rst | 0 .../examples/Australia.ipynb | 0 .../examples/nest_detection.ipynb | 0 .../deepforest_docs => }/figures/Figure_1.png | Bin .../figures/TrainingData.png | Bin .../figures/batch_classification_loss.svg | 0 .../figures/batch_regression_loss.svg | 0 .../figures/iou_equation.png | Bin .../figures/output_17_1.png | Bin .../figures/output_23_1.png | Bin .../figures/output_28_0.png | Bin .../figures/output_32_0.png | Bin .../figures/output_65_0.png | Bin .../figures/output_70_0.png | Bin .../figures/output_71_0.png | Bin .../tree_predicted_bounding_boxes.jpeg | Bin .../figures/tree_predicted_labels.jpeg | Bin .../getting_started/Reading_and_Writing.md | 0 .../getting_started.md | 2 +- .../getting_started/index.rst | 0 .../getting_started/sample.ipynb | 0 docs/index.rst | 9 + .../ConfigurationFile.md | 0 .../installation_and_setup/index.rst | 0 .../installation_and_setup/installation.md | 0 .../installation_and_setup/prebuilt.md | 6 +- .../introduction/index.rst | 1 - .../introduction/landing.md | 2 +- docs/{source/deepforest_docs => }/make.bat | 0 .../related_work/index.rst | 0 .../deepforest_docs => }/related_work/use.md | 0 .../deepforest_docs => }/requirements.txt | 0 .../source => }/deepforest.data.rst | 0 .../source => }/deepforest.rst | 0 .../developer_resources/authors.rst | 1 - .../developer_resources/history.rst | 1 - .../getting_started/getting_started.md | 83 ------ docs/source/deepforest_docs/index.rst | 61 ---- docs/source/index.rst | 60 +++- .../{deepforest_docs/source => }/modules.rst | 0 .../training_and_evaluation/Evaluation.md | 0 .../training_and_evaluation/better.md | 4 +- .../training_and_evaluation/index.rst | 0 .../training_and_evaluation/training.md | 0 environment.yml | 3 + setup.py | 2 +- tests/conftest.py | 33 ++- tests/profile_dataset.py | 13 +- tests/profile_evaluate.py | 17 +- tests/profile_predict_file.py | 29 +- tests/test_FasterRCNN.py | 29 +- tests/test_IoU.py | 29 +- tests/test_callbacks.py | 22 +- tests/test_data.py | 10 +- tests/test_dataset.py | 112 ++++---- tests/test_download.py | 62 ++-- tests/test_environment.py | 2 +- tests/test_evaluate.py | 73 +++-- tests/test_main.py | 271 ++++++++++-------- tests/test_model.py | 24 +- tests/test_multiprocessing.py | 14 +- tests/test_preprocess.py | 112 +++++--- tests/test_retinanet.py | 40 +-- tests/test_utilities.py | 218 ++++++++------ tests/test_visualize.py | 42 +-- 86 files changed, 805 insertions(+), 766 deletions(-) rename docs/{source/deepforest_docs => }/CONTRIBUTING.md (100%) rename docs/{source => }/_templates/navbar2.html (100%) rename docs/{source/deepforest_docs => }/advanced_features/CropModels.md (100%) rename docs/{source/deepforest_docs => }/advanced_features/ExtendingModule.md (100%) rename docs/{source/deepforest_docs => }/advanced_features/Model_Architecture.md (100%) rename docs/{source/deepforest_docs => }/advanced_features/index.rst (100%) rename docs/{source/deepforest_docs => }/advanced_features/multi_species.md (100%) rename docs/{source/deepforest_docs => }/advanced_features/scaling.md (100%) rename docs/{source/deepforest_docs => }/advanced_features/visualizations.md (100%) rename docs/{source => }/conf.py (96%) rename docs/{source/deepforest_docs => }/data_annotation/annotation.md (99%) rename docs/{source/deepforest_docs => }/data_annotation/index.rst (100%) rename docs/{source/deepforest_docs => }/deepforestr.md (100%) create mode 100644 docs/developer_resources/authors.rst rename docs/{source/deepforest_docs => }/developer_resources/code_of_conduct.rst (100%) create mode 100644 docs/developer_resources/history.rst rename docs/{source/deepforest_docs => }/developer_resources/index.rst (100%) rename docs/{source/deepforest_docs => }/examples/Australia.ipynb (100%) rename docs/{source/deepforest_docs => }/examples/nest_detection.ipynb (100%) rename docs/{source/deepforest_docs => }/figures/Figure_1.png (100%) rename docs/{source/deepforest_docs => }/figures/TrainingData.png (100%) rename docs/{source/deepforest_docs => }/figures/batch_classification_loss.svg (100%) rename docs/{source/deepforest_docs => }/figures/batch_regression_loss.svg (100%) rename docs/{source/deepforest_docs => }/figures/iou_equation.png (100%) rename docs/{source/deepforest_docs => }/figures/output_17_1.png (100%) rename docs/{source/deepforest_docs => }/figures/output_23_1.png (100%) rename docs/{source/deepforest_docs => }/figures/output_28_0.png (100%) rename docs/{source/deepforest_docs => }/figures/output_32_0.png (100%) rename docs/{source/deepforest_docs => }/figures/output_65_0.png (100%) rename docs/{source/deepforest_docs => }/figures/output_70_0.png (100%) rename docs/{source/deepforest_docs => }/figures/output_71_0.png (100%) rename docs/{source/deepforest_docs => }/figures/tree_predicted_bounding_boxes.jpeg (100%) rename docs/{source/deepforest_docs => }/figures/tree_predicted_labels.jpeg (100%) rename docs/{source/deepforest_docs => }/getting_started/Reading_and_Writing.md (100%) rename docs/{source/deepforest_docs/introduction => getting_started}/getting_started.md (98%) rename docs/{source/deepforest_docs => }/getting_started/index.rst (100%) rename docs/{source/deepforest_docs => }/getting_started/sample.ipynb (100%) create mode 100644 docs/index.rst rename docs/{source/deepforest_docs => }/installation_and_setup/ConfigurationFile.md (100%) rename docs/{source/deepforest_docs => }/installation_and_setup/index.rst (100%) rename docs/{source/deepforest_docs => }/installation_and_setup/installation.md (100%) rename docs/{source/deepforest_docs => }/installation_and_setup/prebuilt.md (97%) rename docs/{source/deepforest_docs => }/introduction/index.rst (87%) rename docs/{source/deepforest_docs => }/introduction/landing.md (99%) rename docs/{source/deepforest_docs => }/make.bat (100%) rename docs/{source/deepforest_docs => }/related_work/index.rst (100%) rename docs/{source/deepforest_docs => }/related_work/use.md (100%) rename docs/{source/deepforest_docs => }/requirements.txt (100%) rename docs/source/{deepforest_docs/source => }/deepforest.data.rst (100%) rename docs/source/{deepforest_docs/source => }/deepforest.rst (100%) delete mode 100644 docs/source/deepforest_docs/developer_resources/authors.rst delete mode 100644 docs/source/deepforest_docs/developer_resources/history.rst delete mode 100644 docs/source/deepforest_docs/getting_started/getting_started.md delete mode 100644 docs/source/deepforest_docs/index.rst rename docs/source/{deepforest_docs/source => }/modules.rst (100%) rename docs/{source/deepforest_docs => }/training_and_evaluation/Evaluation.md (100%) rename docs/{source/deepforest_docs => }/training_and_evaluation/better.md (97%) rename docs/{source/deepforest_docs => }/training_and_evaluation/index.rst (100%) rename docs/{source/deepforest_docs => }/training_and_evaluation/training.md (100%) diff --git a/.github/workflows/Conda-app.yml b/.github/workflows/Conda-app.yml index 220587333..19f4e0069 100644 --- a/.github/workflows/Conda-app.yml +++ b/.github/workflows/Conda-app.yml @@ -36,7 +36,6 @@ jobs: run: | sudo apt-get update sudo apt-get install -y libgl1-mesa-glx libegl1-mesa - sudo apt-get install -y python3-opencv - name: Install Conda environment with Micromamba uses: mamba-org/setup-micromamba@v1 diff --git a/deepforest/utilities.py b/deepforest/utilities.py index 67f5fb4df..917795b39 100644 --- a/deepforest/utilities.py +++ b/deepforest/utilities.py @@ -20,6 +20,8 @@ from deepforest import _ROOT import json import urllib.request +from huggingface_hub import hf_hub_download +from huggingface_hub.utils._errors import RevisionNotFoundError, HfHubHTTPError def read_config(config_path): @@ -54,6 +56,50 @@ def update_to(self, b=1, bsize=1, tsize=None): self.update(b * bsize - self.n) +def fetch_model(save_dir, repo_id, model_filename, version="main"): + """Downloads a model from Hugging Face and saves it to a specified + directory. + + Parameters: + - save_dir (str): The directory where the model will be saved. + - repo_id (str): The ID of the Hugging Face repository (e.g., "weecology/deepforest-tree"). + - model_filename (str): The name of the model file in the repository (e.g., "NEON.pt"). + - version (str): The version or branch of the model to download (e.g., "main", "v1.0.0"). Default is "main". + + Returns: + - output_path (str): The path where the model is saved. + """ + # Ensure the save directory exists + os.makedirs(save_dir, exist_ok=True) + + # Define the output path + output_path = os.path.join(save_dir, model_filename) + + try: + # Download the model from Hugging Face + hf_hub_download( + repo_id=repo_id, + filename=model_filename, + local_dir=save_dir, + revision=version # Specify the version or branch of the model to download + ) + print(f"Model saved to: {output_path}") + except RevisionNotFoundError as e: + print(f"Error: {e}") + print( + f"Check that the file '{model_filename}' and revision '{version}' exist in the repository '{repo_id}'." + ) + except HfHubHTTPError as e: + print(f"HTTP Error: {e}") + print( + "There might be a problem with your internet connection or the file may not exist." + ) + except Exception as e: + print(f"An unexpected error occurred: {e}") + + return version, output_path + + def use_bird_release( save_dir=os.path.join(_ROOT, "data/"), prebuilt_model="bird", check_release=True): """ @@ -63,61 +109,10 @@ def use_bird_release( prebuilt_model: Currently only accepts "NEON", but could be expanded to include other prebuilt models. The local model will be called prebuilt_model.h5 on disk. check_release (logical): whether to check github for a model recent release. In cases where you are hitting the github API rate limit, set to False and any local model will be downloaded. If no model has been downloaded an error will raise. Returns: release_tag, output_path (str): path to downloaded model - """ - - # Naming based on pre-built model - output_path = os.path.join(save_dir, prebuilt_model + ".pt") - - if check_release: - # Find latest github tag release from the DeepLidar repo - _json = json.loads( - urllib.request.urlopen( - urllib.request.Request( - 'https://api.github.com/repos/Weecology/BirdDetector/releases/latest', - headers={'Accept': 'application/vnd.github.v3+json'}, - )).read()) - asset = _json['assets'][0] - url = asset['browser_download_url'] - - # Check the release tagged locally - try: - release_txt = pd.read_csv(save_dir + "current_bird_release.csv") - except BaseException: - release_txt = pd.DataFrame({"current_bird_release": [None]}) - - # Download the current release it doesn't exist - if not release_txt.current_bird_release[0] == _json["html_url"]: - - print("Downloading model from BirdDetector release {}, see {} for details". - format(_json["tag_name"], _json["html_url"])) - - with DownloadProgressBar(unit='B', - unit_scale=True, - miniters=1, - desc=url.split('/')[-1]) as t: - urllib.request.urlretrieve(url, - filename=output_path, - reporthook=t.update_to) - - print("Model was downloaded and saved to {}".format(output_path)) - - # record the release tag locally - release_txt = pd.DataFrame({"current_bird_release": [_json["html_url"]]}) - release_txt.to_csv(save_dir + "current_bird_release.csv") - else: - print("Model from BirdDetector Repo release {} was already downloaded. " - "Loading model from file.".format(_json["html_url"])) - - return _json["html_url"], output_path - else: - try: - release_txt = pd.read_csv(save_dir + "current_release.csv") - except BaseException: - raise ValueError("Check release argument is {}, but no release has been " - "previously downloaded".format(check_release)) - - return release_txt.current_release[0], output_path + return fetch_model(save_dir, + repo_id="weecology/deepforest-bird", + model_filename="bird.pt") def use_release( @@ -132,58 +127,9 @@ def use_release( Returns: release_tag, output_path (str): path to downloaded model """ - # Naming based on pre-built model - output_path = os.path.join(save_dir, prebuilt_model + ".pt") - - if check_release: - # Find latest github tag release from the DeepLidar repo - _json = json.loads( - urllib.request.urlopen( - urllib.request.Request( - 'https://api.github.com/repos/Weecology/DeepForest/releases/latest', - headers={'Accept': 'application/vnd.github.v3+json'}, - )).read()) - asset = _json['assets'][0] - url = asset['browser_download_url'] - - # Check the release tagged locally - try: - release_txt = pd.read_csv(save_dir + "current_release.csv") - except BaseException: - release_txt = pd.DataFrame({"current_release": [None]}) - - # Download the current release it doesn't exist - if not release_txt.current_release[0] == _json["html_url"]: - - print("Downloading model from DeepForest release {}, see {} " - "for details".format(_json["tag_name"], _json["html_url"])) - - with DownloadProgressBar(unit='B', - unit_scale=True, - miniters=1, - desc=url.split('/')[-1]) as t: - urllib.request.urlretrieve(url, - filename=output_path, - reporthook=t.update_to) - - print("Model was downloaded and saved to {}".format(output_path)) - - # record the release tag locally - release_txt = pd.DataFrame({"current_release": [_json["html_url"]]}) - release_txt.to_csv(save_dir + "current_release.csv") - else: - print("Model from DeepForest release {} was already downloaded. " - "Loading model from file.".format(_json["html_url"])) - - return _json["html_url"], output_path - else: - try: - release_txt = pd.read_csv(save_dir + "current_release.csv") - except BaseException: - raise ValueError("Check release argument is {}, but no release " - "has been previously downloaded".format(check_release)) - - return release_txt.current_release[0], output_path + return fetch_model(save_dir, + repo_id="weecology/deepforest-tree", + model_filename="NEON.pt") def read_pascal_voc(xml_path): @@ -836,4 +782,4 @@ def project_boxes(df, root_dir, transform=True): raise NotImplementedError( "This function is deprecated. Please use image_to_geo_coordinates instead.") - return df \ No newline at end of file + return df diff --git a/deepforest/visualize.py b/deepforest/visualize.py index c80fc19f4..e9d4e9872 100644 --- a/deepforest/visualize.py +++ b/deepforest/visualize.py @@ -2,6 +2,7 @@ import os import pandas as pd import matplotlib +from matplotlib import pyplot as plt from PIL import Image import numpy as np import pandas.api.types as ptypes @@ -31,8 +32,7 @@ def view_dataset(ds, savedir=None, color=None, thickness=1): if savedir: cv2.imwrite("{}/{}".format(savedir, image_path[0]), image) else: - cv2.imshow(image) - cv2.waitKey(0) + plt.imshow(image) def format_geometry(predictions, scores=True): diff --git a/dev_requirements.txt b/dev_requirements.txt index 18fb090f6..f1b17776c 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -7,6 +7,7 @@ docformatter docutils<0.18 pydata-sphinx-theme geopandas +huggingface_hub h5py matplotlib nbmake @@ -14,7 +15,7 @@ nbsphinx nbqa numpy numpydoc -opencv-python>=4.5.4 +opencv-python-headless>=4.5.4 pandas pillow>6.2.0 pip diff --git a/docs/source/deepforest_docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md similarity index 100% rename from docs/source/deepforest_docs/CONTRIBUTING.md rename to docs/CONTRIBUTING.md diff --git a/docs/Makefile b/docs/Makefile index 5bbd3e43e..1821294f6 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ SPHINXOPTS = SPHINXBUILD = python -msphinx SPHINXPROJ = deepforest -SOURCEDIR = source +SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". diff --git a/docs/source/_templates/navbar2.html b/docs/_templates/navbar2.html similarity index 100% rename from docs/source/_templates/navbar2.html rename to docs/_templates/navbar2.html diff --git a/docs/source/deepforest_docs/advanced_features/CropModels.md b/docs/advanced_features/CropModels.md similarity index 100% rename from docs/source/deepforest_docs/advanced_features/CropModels.md rename to docs/advanced_features/CropModels.md diff --git a/docs/source/deepforest_docs/advanced_features/ExtendingModule.md b/docs/advanced_features/ExtendingModule.md similarity index 100% rename from docs/source/deepforest_docs/advanced_features/ExtendingModule.md rename to docs/advanced_features/ExtendingModule.md diff --git a/docs/source/deepforest_docs/advanced_features/Model_Architecture.md b/docs/advanced_features/Model_Architecture.md similarity index 100% rename from docs/source/deepforest_docs/advanced_features/Model_Architecture.md rename to docs/advanced_features/Model_Architecture.md diff --git a/docs/source/deepforest_docs/advanced_features/index.rst b/docs/advanced_features/index.rst similarity index 100% rename from docs/source/deepforest_docs/advanced_features/index.rst rename to docs/advanced_features/index.rst diff --git a/docs/source/deepforest_docs/advanced_features/multi_species.md b/docs/advanced_features/multi_species.md similarity index 100% rename from docs/source/deepforest_docs/advanced_features/multi_species.md rename to docs/advanced_features/multi_species.md diff --git a/docs/source/deepforest_docs/advanced_features/scaling.md b/docs/advanced_features/scaling.md similarity index 100% rename from docs/source/deepforest_docs/advanced_features/scaling.md rename to docs/advanced_features/scaling.md diff --git a/docs/source/deepforest_docs/advanced_features/visualizations.md b/docs/advanced_features/visualizations.md similarity index 100% rename from docs/source/deepforest_docs/advanced_features/visualizations.md rename to docs/advanced_features/visualizations.md diff --git a/docs/source/conf.py b/docs/conf.py similarity index 96% rename from docs/source/conf.py rename to docs/conf.py index 44ff215dd..22194ebb7 100644 --- a/docs/source/conf.py +++ b/docs/conf.py @@ -25,7 +25,7 @@ using [reticulate](https://rstudio.github.io/reticulate/) works. """ -file_obj = open('deepforest_docs/deepforestr.md', 'w') +file_obj = open('deepforestr.md', 'w') readme_url = 'https://raw.githubusercontent.com/weecology/deepforestr/main/README.md' file_obj.write(deepforestr_title) @@ -42,7 +42,7 @@ # Create copy of CONTRIBUTING.md contributing_url = "https://raw.githubusercontent.com/weecology/DeepForest/main/CONTRIBUTING.md" -contributing_source = "../../CONTRIBUTING.md" +contributing_source = "../CONTRIBUTING.md" if not os.path.exists(contributing_source): with urllib.request.urlopen(contributing_url) as response: @@ -52,7 +52,7 @@ # reading from file1 and writing to file2 with open(contributing_source, "r") as file1: - with open("deepforest_docs/CONTRIBUTING.md", "w") as file2: + with open("CONTRIBUTING.md", "w") as file2: file2.write(file1.read()) needs_sphinx = "1.8" @@ -78,7 +78,7 @@ templates_path = ['_templates'] # The master toctree document. -master_doc = 'deepforest_docs/index' +master_doc = 'index' # General information about the project. project = u'DeepForest' diff --git a/docs/source/deepforest_docs/data_annotation/annotation.md b/docs/data_annotation/annotation.md similarity index 99% rename from docs/source/deepforest_docs/data_annotation/annotation.md rename to docs/data_annotation/annotation.md index c8a9fd98d..ce6365d4c 100644 --- a/docs/source/deepforest_docs/data_annotation/annotation.md +++ b/docs/data_annotation/annotation.md @@ -4,12 +4,12 @@ Annotation is the most important part of machine learning projects. If you aren ## How should I annotate images? For quick annotations of a few images, we recommend using QGIS or ArcGIS. Either as project or unprojected data. Create a shapefile for each image. -![QGISannotation](../../../../www/QGIS_annotation.png) +![QGISannotation](../../www/QGIS_annotation.png) ### Label-studio For longer term projects, we recommend [label-studio](https://labelstud.io/) as an annotation platform. It has many useful features and is easy to set up. -![QGISannotation](../../../../www/label_studio.png) +![QGISannotation](../../www/label_studio.png) ## Do I need annotate all objects in my image? Yes! Object detection models use the non-annotated areas of an image as negative data. We know that it can be difficult to annotate all objects in an image, but non-annotation will cause the model *to ignore* objects that should be treated as positive samples, leading to poor model performance. diff --git a/docs/source/deepforest_docs/data_annotation/index.rst b/docs/data_annotation/index.rst similarity index 100% rename from docs/source/deepforest_docs/data_annotation/index.rst rename to docs/data_annotation/index.rst diff --git a/docs/source/deepforest_docs/deepforestr.md b/docs/deepforestr.md similarity index 100% rename from docs/source/deepforest_docs/deepforestr.md rename to docs/deepforestr.md diff --git a/docs/developer_resources/authors.rst b/docs/developer_resources/authors.rst new file mode 100644 index 000000000..7739272f9 --- /dev/null +++ b/docs/developer_resources/authors.rst @@ -0,0 +1 @@ +.. include:: ../../AUTHORS.rst diff --git a/docs/source/deepforest_docs/developer_resources/code_of_conduct.rst b/docs/developer_resources/code_of_conduct.rst similarity index 100% rename from docs/source/deepforest_docs/developer_resources/code_of_conduct.rst rename to docs/developer_resources/code_of_conduct.rst diff --git a/docs/developer_resources/history.rst b/docs/developer_resources/history.rst new file mode 100644 index 000000000..5f2e348f2 --- /dev/null +++ b/docs/developer_resources/history.rst @@ -0,0 +1 @@ +.. include:: ../../HISTORY.rst diff --git a/docs/source/deepforest_docs/developer_resources/index.rst b/docs/developer_resources/index.rst similarity index 100% rename from docs/source/deepforest_docs/developer_resources/index.rst rename to docs/developer_resources/index.rst diff --git a/docs/source/deepforest_docs/examples/Australia.ipynb b/docs/examples/Australia.ipynb similarity index 100% rename from docs/source/deepforest_docs/examples/Australia.ipynb rename to docs/examples/Australia.ipynb diff --git a/docs/source/deepforest_docs/examples/nest_detection.ipynb b/docs/examples/nest_detection.ipynb similarity index 100% rename from docs/source/deepforest_docs/examples/nest_detection.ipynb rename to docs/examples/nest_detection.ipynb diff --git a/docs/source/deepforest_docs/figures/Figure_1.png b/docs/figures/Figure_1.png similarity index 100% rename from docs/source/deepforest_docs/figures/Figure_1.png rename to docs/figures/Figure_1.png diff --git a/docs/source/deepforest_docs/figures/TrainingData.png b/docs/figures/TrainingData.png similarity index 100% rename from docs/source/deepforest_docs/figures/TrainingData.png rename to docs/figures/TrainingData.png diff --git a/docs/source/deepforest_docs/figures/batch_classification_loss.svg b/docs/figures/batch_classification_loss.svg similarity index 100% rename from docs/source/deepforest_docs/figures/batch_classification_loss.svg rename to docs/figures/batch_classification_loss.svg diff --git a/docs/source/deepforest_docs/figures/batch_regression_loss.svg b/docs/figures/batch_regression_loss.svg similarity index 100% rename from docs/source/deepforest_docs/figures/batch_regression_loss.svg rename to docs/figures/batch_regression_loss.svg diff --git a/docs/source/deepforest_docs/figures/iou_equation.png b/docs/figures/iou_equation.png similarity index 100% rename from docs/source/deepforest_docs/figures/iou_equation.png rename to docs/figures/iou_equation.png diff --git a/docs/source/deepforest_docs/figures/output_17_1.png b/docs/figures/output_17_1.png similarity index 100% rename from docs/source/deepforest_docs/figures/output_17_1.png rename to docs/figures/output_17_1.png diff --git a/docs/source/deepforest_docs/figures/output_23_1.png b/docs/figures/output_23_1.png similarity index 100% rename from docs/source/deepforest_docs/figures/output_23_1.png rename to docs/figures/output_23_1.png diff --git a/docs/source/deepforest_docs/figures/output_28_0.png b/docs/figures/output_28_0.png similarity index 100% rename from docs/source/deepforest_docs/figures/output_28_0.png rename to docs/figures/output_28_0.png diff --git a/docs/source/deepforest_docs/figures/output_32_0.png b/docs/figures/output_32_0.png similarity index 100% rename from docs/source/deepforest_docs/figures/output_32_0.png rename to docs/figures/output_32_0.png diff --git a/docs/source/deepforest_docs/figures/output_65_0.png b/docs/figures/output_65_0.png similarity index 100% rename from docs/source/deepforest_docs/figures/output_65_0.png rename to docs/figures/output_65_0.png diff --git a/docs/source/deepforest_docs/figures/output_70_0.png b/docs/figures/output_70_0.png similarity index 100% rename from docs/source/deepforest_docs/figures/output_70_0.png rename to docs/figures/output_70_0.png diff --git a/docs/source/deepforest_docs/figures/output_71_0.png b/docs/figures/output_71_0.png similarity index 100% rename from docs/source/deepforest_docs/figures/output_71_0.png rename to docs/figures/output_71_0.png diff --git a/docs/source/deepforest_docs/figures/tree_predicted_bounding_boxes.jpeg b/docs/figures/tree_predicted_bounding_boxes.jpeg similarity index 100% rename from docs/source/deepforest_docs/figures/tree_predicted_bounding_boxes.jpeg rename to docs/figures/tree_predicted_bounding_boxes.jpeg diff --git a/docs/source/deepforest_docs/figures/tree_predicted_labels.jpeg b/docs/figures/tree_predicted_labels.jpeg similarity index 100% rename from docs/source/deepforest_docs/figures/tree_predicted_labels.jpeg rename to docs/figures/tree_predicted_labels.jpeg diff --git a/docs/source/deepforest_docs/getting_started/Reading_and_Writing.md b/docs/getting_started/Reading_and_Writing.md similarity index 100% rename from docs/source/deepforest_docs/getting_started/Reading_and_Writing.md rename to docs/getting_started/Reading_and_Writing.md diff --git a/docs/source/deepforest_docs/introduction/getting_started.md b/docs/getting_started/getting_started.md similarity index 98% rename from docs/source/deepforest_docs/introduction/getting_started.md rename to docs/getting_started/getting_started.md index 34a68bb54..632607078 100644 --- a/docs/source/deepforest_docs/introduction/getting_started.md +++ b/docs/getting_started/getting_started.md @@ -21,7 +21,7 @@ img = model.predict_image(path=sample_image_path, return_plot=True) plt.imshow(img[:,:,::-1]) ``` -![](../../../../www/getting_started1.png) +![](../../www/getting_started1.png) ** please note that this video was made before the deepforest-pytorch -> deepforest name change. ** diff --git a/docs/source/deepforest_docs/getting_started/index.rst b/docs/getting_started/index.rst similarity index 100% rename from docs/source/deepforest_docs/getting_started/index.rst rename to docs/getting_started/index.rst diff --git a/docs/source/deepforest_docs/getting_started/sample.ipynb b/docs/getting_started/sample.ipynb similarity index 100% rename from docs/source/deepforest_docs/getting_started/sample.ipynb rename to docs/getting_started/sample.ipynb diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 000000000..d173e86a3 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,9 @@ +:orphan: + +DeepForest +======================== + +.. toctree:: + :maxdepth: 2 + + source/index diff --git a/docs/source/deepforest_docs/installation_and_setup/ConfigurationFile.md b/docs/installation_and_setup/ConfigurationFile.md similarity index 100% rename from docs/source/deepforest_docs/installation_and_setup/ConfigurationFile.md rename to docs/installation_and_setup/ConfigurationFile.md diff --git a/docs/source/deepforest_docs/installation_and_setup/index.rst b/docs/installation_and_setup/index.rst similarity index 100% rename from docs/source/deepforest_docs/installation_and_setup/index.rst rename to docs/installation_and_setup/index.rst diff --git a/docs/source/deepforest_docs/installation_and_setup/installation.md b/docs/installation_and_setup/installation.md similarity index 100% rename from docs/source/deepforest_docs/installation_and_setup/installation.md rename to docs/installation_and_setup/installation.md diff --git a/docs/source/deepforest_docs/installation_and_setup/prebuilt.md b/docs/installation_and_setup/prebuilt.md similarity index 97% rename from docs/source/deepforest_docs/installation_and_setup/prebuilt.md rename to docs/installation_and_setup/prebuilt.md index b83d0ba7a..f5ffd7b93 100644 --- a/docs/source/deepforest_docs/installation_and_setup/prebuilt.md +++ b/docs/installation_and_setup/prebuilt.md @@ -6,7 +6,7 @@ DeepForest has two prebuilt models: Bird Detection and Tree Crown Detection. The model was initially described in [Remote Sensing](https://www.mdpi.com/2072-4292/11/11/1309) on a single site. The prebuilt model uses a semi-supervised approach in which millions of moderate quality annotations are generated using a LiDAR unsupervised tree detection algorithm, followed by hand-annotations of RGB imagery from select sites. Comparisons among geographic sites were added to [Ecological Informatics](https://www.sciencedirect.com/science/article/pii/S157495412030011X). The model was further improved, and the Python package was released in [Methods in Ecology and Evolution](https://besjournals.onlinelibrary.wiley.com/doi/full/10.1111/2041-210X.13472). -![image](../../../../www/MEE_Figure4.png) +![image](../../www/MEE_Figure4.png) ### Citation > Weinstein, B.G.; Marconi, S.; Bohlman, S.; Zare, A.; White, E. Individual Tree-Crown Detection in RGB Imagery Using Semi-Supervised Deep Learning Neural Networks. Remote Sens. 2019, 11, 1309 @@ -23,7 +23,7 @@ The model was initially described in [Ecological Applications](https://esajourna ### Citation > Weinstein, B.G., Garner, L., Saccomanno, V.R., Steinkraus, A., Ortega, A., Brush, K., Yenni, G., McKellar, A.E., Converse, R., Lippitt, C.D., Wegmann, A., Holmes, N.D., Edney, A.J., Hart, T., Jessopp, M.J., Clarke, R.H., Marchowski, D., Senyondo, H., Dotson, R., White, E.P., Frederick, P. and Ernest, S.K.M. (2022), A general deep learning model for bird detection in high resolution airborne imagery. Ecological Applications. Accepted Author Manuscript e2694. https://doi-org.lp.hscl.ufl.edu/10.1002/eap.2694 -![image](../../../../www/example_predictions_small.png) +![image](../../www/example_predictions_small.png) ``` #Load deepforest model and set bird label @@ -31,7 +31,7 @@ m = main.deepforest(label_dict={"Bird":0}) m.use_bird_release() ``` -![](../../../../www/bird_panel.jpg) +![](../../www/bird_panel.jpg) We have created a [GPU colab tutorial](https://colab.research.google.com/drive/1e9_pZM0n_v3MkZpSjVRjm55-LuCE2IYE?usp=sharing ) to demonstrate the workflow for using the bird model. diff --git a/docs/source/deepforest_docs/introduction/index.rst b/docs/introduction/index.rst similarity index 87% rename from docs/source/deepforest_docs/introduction/index.rst rename to docs/introduction/index.rst index 301248681..5e6b44bbb 100644 --- a/docs/source/deepforest_docs/introduction/index.rst +++ b/docs/introduction/index.rst @@ -8,4 +8,3 @@ Welcome to the introduction section. :caption: Contents: landing - getting_started diff --git a/docs/source/deepforest_docs/introduction/landing.md b/docs/introduction/landing.md similarity index 99% rename from docs/source/deepforest_docs/introduction/landing.md rename to docs/introduction/landing.md index 5b7e0108d..284fc1603 100644 --- a/docs/source/deepforest_docs/introduction/landing.md +++ b/docs/introduction/landing.md @@ -2,7 +2,7 @@ DeepForest is a python package for training and predicting ecological objects in airborne imagery. DeepForest comes with models for immediate use and finetuning. Both are single class modules that can be extended to species classification based on new data. Users can extend these models by annotating and training custom models. -![](../../../../www/image.png) +![](../../www/image.png) ## How does deepforest work? diff --git a/docs/source/deepforest_docs/make.bat b/docs/make.bat similarity index 100% rename from docs/source/deepforest_docs/make.bat rename to docs/make.bat diff --git a/docs/source/deepforest_docs/related_work/index.rst b/docs/related_work/index.rst similarity index 100% rename from docs/source/deepforest_docs/related_work/index.rst rename to docs/related_work/index.rst diff --git a/docs/source/deepforest_docs/related_work/use.md b/docs/related_work/use.md similarity index 100% rename from docs/source/deepforest_docs/related_work/use.md rename to docs/related_work/use.md diff --git a/docs/source/deepforest_docs/requirements.txt b/docs/requirements.txt similarity index 100% rename from docs/source/deepforest_docs/requirements.txt rename to docs/requirements.txt diff --git a/docs/source/deepforest_docs/source/deepforest.data.rst b/docs/source/deepforest.data.rst similarity index 100% rename from docs/source/deepforest_docs/source/deepforest.data.rst rename to docs/source/deepforest.data.rst diff --git a/docs/source/deepforest_docs/source/deepforest.rst b/docs/source/deepforest.rst similarity index 100% rename from docs/source/deepforest_docs/source/deepforest.rst rename to docs/source/deepforest.rst diff --git a/docs/source/deepforest_docs/developer_resources/authors.rst b/docs/source/deepforest_docs/developer_resources/authors.rst deleted file mode 100644 index 84fb6da82..000000000 --- a/docs/source/deepforest_docs/developer_resources/authors.rst +++ /dev/null @@ -1 +0,0 @@ -.. include:: ../../../../AUTHORS.rst diff --git a/docs/source/deepforest_docs/developer_resources/history.rst b/docs/source/deepforest_docs/developer_resources/history.rst deleted file mode 100644 index 20cf9a864..000000000 --- a/docs/source/deepforest_docs/developer_resources/history.rst +++ /dev/null @@ -1 +0,0 @@ -.. include:: ../../../../HISTORY.rst diff --git a/docs/source/deepforest_docs/getting_started/getting_started.md b/docs/source/deepforest_docs/getting_started/getting_started.md deleted file mode 100644 index 34a68bb54..000000000 --- a/docs/source/deepforest_docs/getting_started/getting_started.md +++ /dev/null @@ -1,83 +0,0 @@ -# Getting started - -# Demo - -[Try out the DeepForest models online!](https://huggingface.co/spaces/weecology/deepforest-demo) - -## How do I use a pretrained model to predict an image? - -```python -from deepforest import main -from deepforest import get_data -import matplotlib.pyplot as plt - -model = main.deepforest() -model.use_release() - -sample_image_path = get_data("OSBS_029.png") -img = model.predict_image(path=sample_image_path, return_plot=True) - -#predict_image returns plot in BlueGreenRed (opencv style), but matplotlib likes RedGreenBlue, switch the channel order. Many functions in deepforest will automatically perform this flip for you and give a warning. -plt.imshow(img[:,:,::-1]) -``` - -![](../../../../www/getting_started1.png) - - -** please note that this video was made before the deepforest-pytorch -> deepforest name change. ** - -
- -For single images, ```predict_image``` can read an image from memory or file and return predicted bounding boxes. - -### Sample data - -DeepForest comes with a small set of sample data that can be used to test out the provided examples. The data resides in the DeepForest data directory. Use the `get_data` helper function to locate the path to this directory, if needed. - -```python -sample_image = get_data("OSBS_029.png") -sample_image -'/Users/benweinstein/Documents/DeepForest/deepforest/data/OSBS_029.png' -``` - -To use images other than those in the sample data directory, provide the full path for the images. - -```python -image_path = get_data("OSBS_029.png") -boxes = model.predict_image(path=image_path, return_plot = False) -``` - -``` ->>> boxes - xmin ymin xmax ymax label score image_path -0 330.0 342.0 373.0 391.0 Tree 0.802979 OSBS_029.png -1 216.0 206.0 248.0 242.0 Tree 0.778803 OSBS_029.png -2 325.0 44.0 363.0 82.0 Tree 0.751573 OSBS_029.png -3 261.0 238.0 296.0 276.0 Tree 0.748605 OSBS_029.png -4 173.0 0.0 229.0 33.0 Tree 0.738210 OSBS_029.png -5 258.0 198.0 291.0 230.0 Tree 0.716250 OSBS_029.png -6 97.0 305.0 152.0 363.0 Tree 0.711664 OSBS_029.png -7 52.0 72.0 85.0 108.0 Tree 0.698782 OSBS_029.png -``` - -### Predict a tile - -Large tiles covering wide geographic extents cannot fit into memory during prediction and would yield poor results due to the density of bounding boxes. Often provided as geospatial .tif files, remote sensing data is best suited for the ```predict_tile``` function, which splits the tile into overlapping windows, performs prediction on each of the windows, and then reassembles the resulting annotations. - -Let's show an example with a small image. For larger images, patch_size should be increased. - -```python -raster_path = get_data("OSBS_029.tif") -# Window size of 300px with an overlap of 25% among windows for this small tile. -predicted_raster = model.predict_tile(raster_path, return_plot = True, patch_size=300,patch_overlap=0.25) - -# View boxes overlayed when return_plot=True, when False, boxes are returned. -plt.imshow(predicted_raster) -plt.show() -``` - -** Please note the predict tile function is sensitive to patch_size, especially when using the prebuilt model on new data** - -We encourage users to try out a variety of patch sizes. For 0.1m data, 400-800px per window is appropriate, but it will depend on the density of tree plots. For coarser resolution tiles, >800px patch sizes have been effective, but we welcome feedback from users using a variety of spatial resolutions. - - diff --git a/docs/source/deepforest_docs/index.rst b/docs/source/deepforest_docs/index.rst deleted file mode 100644 index 9b04191a6..000000000 --- a/docs/source/deepforest_docs/index.rst +++ /dev/null @@ -1,61 +0,0 @@ -Welcome to DeepForest! -********************** - -DeepForest is a python package for airborne object detection and classification. - -**Tree crown prediction using DeepForest** - -.. image:: ../../../www/OSBS_sample.png - -**Bird detection using DeepForest** - -.. image:: ../../../www/bird_panel.jpg - - -Why DeepForest? -=============== - -Observing the abundance and distribution of individual organisms is one of the foundations of ecology. Connecting broad-scale changes in organismal ecology, such as those associated with climate change, shifts in land use, and invasive species require fine-grained data on individuals at wide spatial and temporal extents. - -To capture these data, ecologists are turning to airborne data collection from uncrewed aerial vehicles, piloted aircraft, and earth-facing satellites. Computer vision, a type of image-based artificial intelligence, has become a reliable tool for converting images into ecological information by detecting and classifying ecological objects within airborne imagery. - -There have been many studies demonstrating that, with sufficient labeling, computer vision can yield near-human-level performance for ecological analysis. However, almost all of these studies rely on isolated computer vision models that require extensive technical expertise and human data labeling. -In addition, the speed of innovation in computer vision makes it difficult for even experts to keep up with new innovations. - -To address these challenges, the next phase of ecological computer vision needs to reduce the technical barriers and move towards general models that can be applied across space, time, and taxa. - -DeepForest aims to be **simple**, **customizable**, and **modular**. DeepForest makes an effort to keep unnecessary complexity hidden from the ordinary user by developing straightforward functions like "predict_tile." The majority of machine learning projects actually fail due to poor data and project management, not clever models. DeepForest makes an effort to generate straightforward defaults, utilize already-existing labeling tools and interfaces, and minimize the effect of learning new APIs and code. - -Where can I get help, learn from others, and report bugs? -========================================================= - -Given the enormous array of forest types and image acquisition environments, it is unlikely that your image will be perfectly predicted by a prebuilt model. Below are some tips and some general guidelines to improve predictions. - -Get suggestions on how to improve a model by using the [discussion board](https://github.com/weecology/DeepForest/discussions). Please be aware that only feature requests or bug reports should be posted on the issues page. The most helpful thing you can do is leave feedback on DeepForest `issue page`_. No feature or issue, or positive affirmation is too small. Please do it now! - - -`Source code`_ is available on GitHub. - -.. toctree:: - :maxdepth: 2 - - introduction/index - installation_and_setup/index - getting_started/index - training_and_evaluation/index - data_annotation/index - advanced_features/index - developer_resources/index - deepforestr - source/modules.rst - related_work/index - - -Indices and tables -================== -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - -.. _issue page: https://github.com/weecology/DeepForest/issues -.. _Source code: https://github.com/weecology/DeepForest.git diff --git a/docs/source/index.rst b/docs/source/index.rst index bb8aa2aac..59be501ea 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,9 +1,61 @@ -:orphan: +Welcome to DeepForest! +********************** -DeepForest -======================== +DeepForest is a python package for airborne object detection and classification. + +**Tree crown prediction using DeepForest** + +.. image:: ../../www/OSBS_sample.png + +**Bird detection using DeepForest** + +.. image:: ../../www/bird_panel.jpg + + +Why DeepForest? +=============== + +Observing the abundance and distribution of individual organisms is one of the foundations of ecology. Connecting broad-scale changes in organismal ecology, such as those associated with climate change, shifts in land use, and invasive species require fine-grained data on individuals at wide spatial and temporal extents. + +To capture these data, ecologists are turning to airborne data collection from uncrewed aerial vehicles, piloted aircraft, and earth-facing satellites. Computer vision, a type of image-based artificial intelligence, has become a reliable tool for converting images into ecological information by detecting and classifying ecological objects within airborne imagery. + +There have been many studies demonstrating that, with sufficient labeling, computer vision can yield near-human-level performance for ecological analysis. However, almost all of these studies rely on isolated computer vision models that require extensive technical expertise and human data labeling. +In addition, the speed of innovation in computer vision makes it difficult for even experts to keep up with new innovations. + +To address these challenges, the next phase of ecological computer vision needs to reduce the technical barriers and move towards general models that can be applied across space, time, and taxa. + +DeepForest aims to be **simple**, **customizable**, and **modular**. DeepForest makes an effort to keep unnecessary complexity hidden from the ordinary user by developing straightforward functions like "predict_tile." The majority of machine learning projects actually fail due to poor data and project management, not clever models. DeepForest makes an effort to generate straightforward defaults, utilize already-existing labeling tools and interfaces, and minimize the effect of learning new APIs and code. + +Where can I get help, learn from others, and report bugs? +========================================================= + +Given the enormous array of forest types and image acquisition environments, it is unlikely that your image will be perfectly predicted by a prebuilt model. Below are some tips and some general guidelines to improve predictions. + +Get suggestions on how to improve a model by using the [discussion board](https://github.com/weecology/DeepForest/discussions). Please be aware that only feature requests or bug reports should be posted on the issues page. The most helpful thing you can do is leave feedback on DeepForest `issue page`_. No feature or issue, or positive affirmation is too small. Please do it now! + + +`Source code`_ is available on GitHub. .. toctree:: :maxdepth: 2 - deepforest_docs/index + ../introduction/index + ../installation_and_setup/index + ../getting_started/index + ../training_and_evaluation/index + ../data_annotation/index + ../advanced_features/index + ../developer_resources/index + ../deepforestr + modules.rst + ../related_work/index + + +Indices and tables +================== +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` + +.. _issue page: https://github.com/weecology/DeepForest/issues +.. _Source code: https://github.com/weecology/DeepForest.git diff --git a/docs/source/deepforest_docs/source/modules.rst b/docs/source/modules.rst similarity index 100% rename from docs/source/deepforest_docs/source/modules.rst rename to docs/source/modules.rst diff --git a/docs/source/deepforest_docs/training_and_evaluation/Evaluation.md b/docs/training_and_evaluation/Evaluation.md similarity index 100% rename from docs/source/deepforest_docs/training_and_evaluation/Evaluation.md rename to docs/training_and_evaluation/Evaluation.md diff --git a/docs/source/deepforest_docs/training_and_evaluation/better.md b/docs/training_and_evaluation/better.md similarity index 97% rename from docs/source/deepforest_docs/training_and_evaluation/better.md rename to docs/training_and_evaluation/better.md index c9827a114..4840be82e 100644 --- a/docs/source/deepforest_docs/training_and_evaluation/better.md +++ b/docs/training_and_evaluation/better.md @@ -13,14 +13,14 @@ The prebuilt model was trained on 10cm data at 400px crops. The model is sensiti tile = model.predict_tile("/Users/ben/Desktop/test.jpg",return_plot=True,patch_overlap=0,iou_threshold=0.05,patch_size=400) ``` -![](../../../../www/example_patch400.png) +![](../../www/example_patch400.png) Acceptable, but not ideal. Here is 1000 px patches. -![](../../../../www/example_patch1000.png) +![](../../www/example_patch1000.png) improved. diff --git a/docs/source/deepforest_docs/training_and_evaluation/index.rst b/docs/training_and_evaluation/index.rst similarity index 100% rename from docs/source/deepforest_docs/training_and_evaluation/index.rst rename to docs/training_and_evaluation/index.rst diff --git a/docs/source/deepforest_docs/training_and_evaluation/training.md b/docs/training_and_evaluation/training.md similarity index 100% rename from docs/source/deepforest_docs/training_and_evaluation/training.md rename to docs/training_and_evaluation/training.md diff --git a/environment.yml b/environment.yml index 22ad1dd3c..bb6594c81 100644 --- a/environment.yml +++ b/environment.yml @@ -6,6 +6,7 @@ dependencies: - bumpversion - docutils<0.18 - geopandas + - huggingface_hub - h5py - matplotlib - nbmake @@ -47,3 +48,5 @@ dependencies: - nbqa - Pygments - docformatter + - opencv-python-headless + diff --git a/setup.py b/setup.py index 7957d9747..bb1933610 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ packages=find_packages(), include_package_data=True, install_requires=[ - "albumentations>=1.0.0", "aiolimiter", "aiohttp", "docformatter", "geopandas", "matplotlib", "nbqa", "numpy", + "albumentations>=1.0.0", "aiolimiter", "aiohttp", "docformatter", "huggingface_hub", "geopandas", "matplotlib", "nbqa", "numpy", "opencv-python>=4.5.4", "pandas", "Pillow>6.2.0", "progressbar2", "pycocotools", "pydata-sphinx-theme", "Pygments", "pytorch-lightning>=1.5.8", "rasterio", "recommonmark", "rtree", "scipy>1.5", "six", "slidingwindow", "sphinx", "supervision", "torch", "torchvision>=0.13", "tqdm", diff --git a/tests/conftest.py b/tests/conftest.py index a1c0a0082..61e6c1e0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -#Fixtures model to only download model once +# Fixtures model to only download model once # download latest release import pytest from deepforest import utilities, main @@ -9,14 +9,16 @@ collect_ignore = ['setup.py'] + @pytest.fixture(scope="session") def config(): - config = utilities.read_config("{}/deepforest_config.yml".format(os.path.dirname(_ROOT))) + config = utilities.read_config("{}/deepforest_config.yml".format( + os.path.dirname(_ROOT))) config["fast_dev_run"] = True config["batch_size"] = True - return config + @pytest.fixture(scope="session") def download_release(): print("running fixtures") @@ -27,41 +29,42 @@ def download_release(): pass assert os.path.exists(get_data("NEON.pt")) + @pytest.fixture(scope="session") def ROOT(): return _ROOT + @pytest.fixture() def two_class_m(): - m = main.deepforest(num_classes=2,label_dict={"Alive":0,"Dead":1}) - m.config["train"]["csv_file"] = get_data("testfile_multi.csv") + m = main.deepforest(num_classes=2, label_dict={"Alive": 0, "Dead": 1}) + m.config["train"]["csv_file"] = get_data("testfile_multi.csv") m.config["train"]["root_dir"] = os.path.dirname(get_data("testfile_multi.csv")) m.config["train"]["fast_dev_run"] = True m.config["batch_size"] = 2 - - m.config["validation"]["csv_file"] = get_data("testfile_multi.csv") + m.config["validation"]["csv_file"] = get_data("testfile_multi.csv") m.config["validation"]["root_dir"] = os.path.dirname(get_data("testfile_multi.csv")) m.config["validation"]["val_accuracy_interval"] = 1 m.create_trainer() - + return m + @pytest.fixture() def m(download_release): m = main.deepforest() - m.config["train"]["csv_file"] = get_data("example.csv") + m.config["train"]["csv_file"] = get_data("example.csv") m.config["train"]["root_dir"] = os.path.dirname(get_data("example.csv")) m.config["train"]["fast_dev_run"] = True m.config["batch_size"] = 2 - - m.config["validation"]["csv_file"] = get_data("example.csv") + m.config["validation"]["csv_file"] = get_data("example.csv") m.config["validation"]["root_dir"] = os.path.dirname(get_data("example.csv")) - m.config["workers"] = 0 + m.config["workers"] = 0 m.config["validation"]["val_accuracy_interval"] = 1 m.config["train"]["epochs"] = 2 - + m.create_trainer() m.use_release(check_release=False) - - return m \ No newline at end of file + + return m diff --git a/tests/profile_dataset.py b/tests/profile_dataset.py index 9eecc5a40..37323cc71 100644 --- a/tests/profile_dataset.py +++ b/tests/profile_dataset.py @@ -1,26 +1,27 @@ -#Profile the dataset class +# Profile the dataset class from deepforest import dataset from deepforest import get_data import os import cProfile, pstats -def run(): + +def run(): csv_file = get_data("OSBS_029.csv") root_dir = os.path.dirname(csv_file) - + ds = dataset.TreeDataset(csv_file=csv_file, root_dir=root_dir, transforms=dataset.get_transform(augment=True)) - + for x in range(1000): next(iter(ds)) + if __name__ == "__main__": profiler = cProfile.Profile() profiler.enable() run() profiler.disable() stats = pstats.Stats(profiler).sort_stats('cumtime') - stats.print_stats() + stats.print_stats() stats.dump_stats('dataset.prof') - \ No newline at end of file diff --git a/tests/profile_evaluate.py b/tests/profile_evaluate.py index 181dab629..15e4ce684 100644 --- a/tests/profile_evaluate.py +++ b/tests/profile_evaluate.py @@ -1,4 +1,4 @@ -#Profile the dataset class +# Profile the dataset class from deepforest import evaluate from deepforest import main from deepforest import get_data @@ -6,18 +6,22 @@ import os import cProfile, pstats -def run(m): +def run(m): csv_file = get_data("OSBS_029.csv") predictions = m.predict_file(csv_file=csv_file, root_dir=os.path.dirname(csv_file)) predictions.label = "Tree" ground_truth = pd.read_csv(csv_file) - results = evaluate.evaluate(predictions=predictions, ground_df=ground_truth, root_dir=os.path.dirname(csv_file), savedir=None) - + results = evaluate.evaluate(predictions=predictions, + ground_df=ground_truth, + root_dir=os.path.dirname(csv_file), + savedir=None) + + if __name__ == "__main__": m = main.deepforest() m.use_release() - + profiler = cProfile.Profile() profiler.enable() m = main.deepforest() @@ -25,6 +29,5 @@ def run(m): run(m) profiler.disable() stats = pstats.Stats(profiler).sort_stats('cumtime') - stats.print_stats() + stats.print_stats() stats.dump_stats('evaluate.prof') - \ No newline at end of file diff --git a/tests/profile_predict_file.py b/tests/profile_predict_file.py index e018abc71..b707049ca 100644 --- a/tests/profile_predict_file.py +++ b/tests/profile_predict_file.py @@ -1,4 +1,4 @@ -#Profile the dataset class on gpu +# Profile the dataset class on gpu from deepforest import main from deepforest import get_data import os @@ -9,35 +9,38 @@ from PIL import Image import cv2 -def run(m, csv_file, root_dir): + +def run(m, csv_file, root_dir): predictions = m.predict_file(csv_file=csv_file, root_dir=root_dir) - + + if __name__ == "__main__": m = main.deepforest() m.use_release() m.config["workers"] = 0 m.config["batch_size"] = 5 - + csv_file = get_data("OSBS_029.csv") image_path = get_data("OSBS_029.png") - tmpdir = tempfile.gettempdir() - df = pd.read_csv(csv_file) - + tmpdir = tempfile.gettempdir() + df = pd.read_csv(csv_file) + big_frame = [] for x in range(100): - img = Image.open("{}/{}".format(os.path.dirname(csv_file), df.image_path.unique()[0])) + img = Image.open("{}/{}".format(os.path.dirname(csv_file), + df.image_path.unique()[0])) cv2.imwrite("{}/{}.png".format(tmpdir, x), np.array(img)) new_df = df.copy() new_df.image_path = "{}.png".format(x) big_frame.append(new_df) - + big_frame = pd.concat(big_frame) big_frame.to_csv("{}/annotations.csv".format(tmpdir)) - + profiler = cProfile.Profile() profiler.enable() - run(m, csv_file = "{}/annotations.csv".format(tmpdir), root_dir = tmpdir) + run(m, csv_file="{}/annotations.csv".format(tmpdir), root_dir=tmpdir) profiler.disable() stats = pstats.Stats(profiler).sort_stats('cumtime') - stats.print_stats() - stats.dump_stats('predict_file.prof') \ No newline at end of file + stats.print_stats() + stats.dump_stats('predict_file.prof') diff --git a/tests/test_FasterRCNN.py b/tests/test_FasterRCNN.py index 13cdca5df..316a4eec2 100644 --- a/tests/test_FasterRCNN.py +++ b/tests/test_FasterRCNN.py @@ -1,4 +1,4 @@ -#test FasterRCNN +# test FasterRCNN from deepforest.models import FasterRCNN from deepforest import get_data import pytest @@ -7,26 +7,30 @@ import torchvision import os -os.environ['KMP_DUPLICATE_LIB_OK']='True' +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' -#Empty tester from https://github.com/datumbox/vision/blob/06ebee1a9f10c76d8ac5768fd578362dd5ace6e9/test/test_models_detection_negative_samples.py#L14 + +# Empty tester from https://github.com/datumbox/vision/blob/06ebee1a9f10c76d8ac5768fd578362dd5ace6e9/test/test_models_detection_negative_samples.py#L14 def _make_empty_sample(): images = [torch.rand((3, 100, 100), dtype=torch.float32)] boxes = torch.zeros((0, 4), dtype=torch.float32) - negative_target = {"boxes": boxes, - "labels": torch.zeros(0, dtype=torch.int64), - "image_id": 4, - "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), - "iscrowd": torch.zeros((0,), dtype=torch.int64)} + negative_target = { + "boxes": boxes, + "labels": torch.zeros(0, dtype=torch.int64), + "image_id": 4, + "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), + "iscrowd": torch.zeros((0,), dtype=torch.int64) + } targets = [negative_target] return images, targets + def test_retinanet(config): r = FasterRCNN.Model(config) - assert r + def test_load_backbone(config): r = FasterRCNN.Model(config) resnet_backbone = r.load_backbone() @@ -34,9 +38,11 @@ def test_load_backbone(config): x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] prediction = resnet_backbone(x) -# This test still fails, do we want a way to pass kwargs directly to method, instead of being limited by config structure? + +# This test still fails, do we want a way to pass kwargs directly to method, +# instead of being limited by config structure? # Need to create issue when I get online. -@pytest.mark.parametrize("num_classes",[1,2,10]) +@pytest.mark.parametrize("num_classes", [1, 2, 10]) def test_create_model(config, num_classes): config["num_classes"] = num_classes retinanet_model = FasterRCNN.Model(config).create_model() @@ -44,6 +50,7 @@ def test_create_model(config, num_classes): x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] predictions = retinanet_model(x) + def test_forward_empty(config): r = FasterRCNN.Model(config) model = r.create_model() diff --git a/tests/test_IoU.py b/tests/test_IoU.py index 287030f0c..20eda8a1d 100644 --- a/tests/test_IoU.py +++ b/tests/test_IoU.py @@ -1,4 +1,4 @@ -#Test IoU +# Test IoU from .conftest import download_release from deepforest import IoU from deepforest import main @@ -10,25 +10,26 @@ import geopandas as gpd import pandas as pd + def test_compute_IoU(m, tmpdir): csv_file = get_data("OSBS_029.csv") predictions = m.predict_file(csv_file=csv_file, root_dir=os.path.dirname(csv_file)) ground_truth = pd.read_csv(csv_file) - - predictions['geometry'] = predictions.apply(lambda x: shapely.geometry.box(x.xmin,x.ymin,x.xmax,x.ymax), axis=1) + + predictions['geometry'] = predictions.apply(lambda x: shapely.geometry.box(x.xmin, x.ymin, x.xmax, x.ymax), axis=1) predictions = gpd.GeoDataFrame(predictions, geometry='geometry') - - ground_truth['geometry'] = ground_truth.apply(lambda x: shapely.geometry.box(x.xmin,x.ymin,x.xmax,x.ymax), axis=1) - ground_truth = gpd.GeoDataFrame(ground_truth, geometry='geometry') - + + ground_truth['geometry'] = ground_truth.apply(lambda x: shapely.geometry.box(x.xmin, x.ymin, x.xmax, x.ymax), + axis=1) + ground_truth = gpd.GeoDataFrame(ground_truth, geometry='geometry') + ground_truth.label = 0 predictions.label = 0 - visualize.plot_prediction_dataframe( - df=predictions, - ground_truth=ground_truth, - root_dir=os.path.dirname(csv_file), - savedir=tmpdir) - + visualize.plot_prediction_dataframe(df=predictions, + ground_truth=ground_truth, + root_dir=os.path.dirname(csv_file), + savedir=tmpdir) + result = IoU.compute_IoU(ground_truth, predictions) assert result.shape[0] == ground_truth.shape[0] - assert sum(result.IoU) > 10 \ No newline at end of file + assert sum(result.IoU) > 10 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 639d55578..f25b2c38b 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,4 +1,4 @@ -#test callbacks +# test callbacks from deepforest import main from deepforest import callbacks import glob @@ -7,6 +7,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from deepforest import get_data + @pytest.mark.parametrize("every_n_epochs", [1, 2, 3]) def test_log_images(m, every_n_epochs, tmpdir): im_callback = callbacks.images_callback(savedir=tmpdir, every_n_epochs=every_n_epochs) @@ -15,14 +16,15 @@ def test_log_images(m, every_n_epochs, tmpdir): saved_images = glob.glob("{}/*.png".format(tmpdir)) assert len(saved_images) == 1 -def test_create_checkpoint(m, tmpdir): + +def test_create_checkpoint(m, tmpdir): checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, - save_top_k=1, - monitor="val_classification", - mode="max", - every_n_epochs=1, - ) + dirpath=tmpdir, + save_top_k=1, + monitor="val_classification", + mode="max", + every_n_epochs=1, + ) m.use_release() - m.create_trainer(callbacks = [checkpoint_callback]) - m.trainer.fit(m) \ No newline at end of file + m.create_trainer(callbacks=[checkpoint_callback]) + m.trainer.fit(m) diff --git a/tests/test_data.py b/tests/test_data.py index b978a96cb..056bf7d55 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,19 +3,21 @@ import deepforest from deepforest.utilities import read_config + # Make sure package data is present def test_get_data(): assert os.path.exists(deepforest.get_data("testfile_deepforest.csv")) - assert os.path.exists(deepforest.get_data("testfile_multi.csv")) - assert os.path.exists(deepforest.get_data("example.csv")) + assert os.path.exists(deepforest.get_data("testfile_multi.csv")) + assert os.path.exists(deepforest.get_data("example.csv")) assert os.path.exists(deepforest.get_data("2019_YELL_2_541000_4977000_image_crop.png")) assert os.path.exists(deepforest.get_data("OSBS_029.png")) assert os.path.exists(deepforest.get_data("OSBS_029.tif")) - assert os.path.exists(deepforest.get_data("SOAP_061.png")) + assert os.path.exists(deepforest.get_data("SOAP_061.png")) assert os.path.exists(deepforest.get_data("classes.csv")) + # Assert that the included config file matches the front of the repo. def test_matching_config(ROOT): config = read_config("{}/deepforest_config.yml".format(os.path.dirname(ROOT))) config_from_data_dir = read_config("{}/data/deepforest_config.yml".format(ROOT)) - assert config == config_from_data_dir \ No newline at end of file + assert config == config_from_data_dir diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2495f80a3..0409aba2a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,4 +1,4 @@ -#test dataset model +# test dataset model from deepforest import get_data from deepforest import dataset from deepforest import utilities @@ -8,66 +8,69 @@ import pandas as pd import numpy as np import tempfile -import rasterio as rio +import rasterio as rio from deepforest.dataset import BoundingBoxDataset + def single_class(): csv_file = get_data("example.csv") - + return csv_file + def multi_class(): csv_file = get_data("testfile_multi.csv") - + return csv_file -@pytest.mark.parametrize("csv_file,label_dict",[(single_class(), {"Tree":0}), (multi_class(),{"Alive":0,"Dead":1})]) -def test_TreeDataset(csv_file, label_dict): + +@pytest.mark.parametrize("csv_file,label_dict", [(single_class(), {"Tree": 0}), (multi_class(), {"Alive": 0, "Dead": 1})]) +def test_tree_dataset(csv_file, label_dict): root_dir = os.path.dirname(get_data("OSBS_029.png")) - ds = dataset.TreeDataset(csv_file=csv_file, - root_dir=root_dir, - label_dict=label_dict) + ds = dataset.TreeDataset(csv_file=csv_file, root_dir=root_dir, label_dict=label_dict) raw_data = pd.read_csv(csv_file) - + assert len(ds) == len(raw_data.image_path.unique()) - + for i in range(len(ds)): - #Between 0 and 1 + # Between 0 and 1 path, image, targets = ds[i] assert image.max() <= 1 assert image.min() >= 0 - assert targets["boxes"].shape == (raw_data.shape[0],4) + assert targets["boxes"].shape == (raw_data.shape[0], 4) assert targets["labels"].shape == (raw_data.shape[0],) assert len(np.unique(targets["labels"])) == len(raw_data.label.unique()) - + + def test_single_class_with_empty(tmpdir): """Add fake empty annotations to test parsing """ csv_file1 = get_data("example.csv") csv_file2 = get_data("OSBS_029.csv") - + df1 = pd.read_csv(csv_file1) df2 = pd.read_csv(csv_file2) - df = pd.concat([df1,df2]) - - df.loc[df.image_path == "OSBS_029.tif","xmin"] = 0 - df.loc[df.image_path == "OSBS_029.tif","ymin"] = 0 - df.loc[df.image_path == "OSBS_029.tif","xmax"] = 0 - df.loc[df.image_path == "OSBS_029.tif","ymax"] = 0 - + df = pd.concat([df1, df2]) + + df.loc[df.image_path == "OSBS_029.tif", "xmin"] = 0 + df.loc[df.image_path == "OSBS_029.tif", "ymin"] = 0 + df.loc[df.image_path == "OSBS_029.tif", "xmax"] = 0 + df.loc[df.image_path == "OSBS_029.tif", "ymax"] = 0 + df.to_csv("{}_test_empty.csv".format(tmpdir)) - + root_dir = os.path.dirname(get_data("OSBS_029.png")) ds = dataset.TreeDataset(csv_file="{}_test_empty.csv".format(tmpdir), root_dir=root_dir, - label_dict={"Tree":0}) + label_dict={"Tree": 0}) assert len(ds) == 2 - #First image has annotations + # First image has annotations assert not torch.sum(ds[0][2]["boxes"]) == 0 - #Second image has no annotations + # Second image has no annotations assert torch.sum(ds[1][2]["boxes"]) == 0 - -@pytest.mark.parametrize("augment",[True,False]) -def test_TreeDataset_transform(augment): + + +@pytest.mark.parametrize("augment", [True, False]) +def test_tree_dataset_transform(augment): csv_file = get_data("example.csv") root_dir = os.path.dirname(csv_file) ds = dataset.TreeDataset(csv_file=csv_file, @@ -75,17 +78,18 @@ def test_TreeDataset_transform(augment): transforms=dataset.get_transform(augment=augment)) for i in range(len(ds)): - #Between 0 and 1 + # Between 0 and 1 path, image, targets = ds[i] assert image.max() <= 1 assert image.min() >= 0 assert targets["boxes"].shape == (79, 4) assert targets["labels"].shape == (79,) - + assert torch.is_tensor(targets["boxes"]) assert torch.is_tensor(targets["labels"]) assert torch.is_tensor(image) + def test_collate(): """Due to data augmentations the dataset class may yield empty bounding box annotations""" csv_file = get_data("example.csv") @@ -95,11 +99,12 @@ def test_collate(): transforms=dataset.get_transform(augment=False)) for i in range(len(ds)): - #Between 0 and 1 + # Between 0 and 1 batch = ds[i] collated_batch = utilities.collate_fn(batch) assert len(collated_batch) == 2 - + + def test_empty_collate(): """Due to data augmentations the dataset class may yield empty bounding box annotations""" csv_file = get_data("example.csv") @@ -109,21 +114,21 @@ def test_empty_collate(): transforms=dataset.get_transform(augment=False)) for i in range(len(ds)): - #Between 0 and 1 + # Between 0 and 1 batch = ds[i] collated_batch = utilities.collate_fn([None, batch, batch]) len(collated_batch[0]) == 2 + def test_dataloader(): csv_file = get_data("example.csv") root_dir = os.path.dirname(csv_file) - ds = dataset.TreeDataset(csv_file=csv_file, - root_dir=root_dir, - train=False) + ds = dataset.TreeDataset(csv_file=csv_file, root_dir=root_dir, train=False) image = next(iter(ds)) - #Assert image is channels first format + # Assert image is channels first format assert image.shape[0] == 3 - + + def test_multi_image_warning(): tmpdir = tempfile.gettempdir() csv_file1 = get_data("example.csv") @@ -133,31 +138,35 @@ def test_multi_image_warning(): df = pd.concat([df1, df2]) csv_file = "{}/multiple.csv".format(tmpdir) df.to_csv(csv_file) - + root_dir = os.path.dirname(csv_file1) ds = dataset.TreeDataset(csv_file=csv_file, root_dir=root_dir, transforms=dataset.get_transform(augment=False)) for i in range(len(ds)): - #Between 0 and 1 + # Between 0 and 1 batch = ds[i] collated_batch = utilities.collate_fn([None, batch, batch]) len(collated_batch[0]) == 2 - -@pytest.mark.parametrize("preload_images",[True, False]) -def test_TileDataset(preload_images): + + +@pytest.mark.parametrize("preload_images", [True, False]) +def test_tile_dataset(preload_images): tile_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png") tile = rio.open(tile_path).read() - tile = np.moveaxis(tile, 0, 2) - ds = dataset.TileDataset(tile=tile, preload_images=preload_images, patch_size=100, patch_overlap=0) + tile = np.moveaxis(tile, 0, 2) + ds = dataset.TileDataset(tile=tile, + preload_images=preload_images, + patch_size=100, + patch_overlap=0) assert len(ds) > 0 - - #assert crop shape + + # assert crop shape assert ds[1].shape == (3, 100, 100) - -def test_BoundingBoxDataset(): + +def test_bounding_box_dataset(): # Create a sample dataframe df = pd.read_csv(get_data("OSBS_029.csv")) @@ -171,5 +180,4 @@ def test_BoundingBoxDataset(): item = ds[0] # Check the shape of the RGB tensor - assert item.shape == (3, 224,224) - + assert item.shape == (3, 224, 224) diff --git a/tests/test_download.py b/tests/test_download.py index 8e345b677..4e00f4b57 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -7,6 +7,7 @@ import rasterio as rio import pytest + def url(): return [ "https://map.dfg.ca.gov/arcgis/rest/services/Base_Remote_Sensing/NAIP_2020_CIR/ImageServer/", @@ -14,42 +15,39 @@ def url(): "https://orthos.its.ny.gov/arcgis/rest/services/wms/Latest/MapServer" ] + def boxes(): - return [ - (-124.112622, 40.493891, -124.111536, 40.49457), - (-114.12529, 51.072134, -114.12117, 51.07332), - (-73.763941, 41.111032, -73.763447, 41.111626) - ] + return [(-124.112622, 40.493891, -124.111536, 40.49457), + (-114.12529, 51.072134, -114.12117, 51.07332), + (-73.763941, 41.111032, -73.763447, 41.111626)] + def additional_params(): - return [ - None, - None, - {"format":"png"} - ] + return [None, None, {"format": "png"}] + def download_service(): - return [ - "exportImage", - "exportImage", - "export" - ] + return ["exportImage", "exportImage", "export"] + # Pair each URL with its corresponding box -url_box_pairs = list(zip(["CA.tif","MA.tif","NY.png"],url(), boxes(), additional_params(), download_service())) +url_box_pairs = list(zip(["CA.tif", "MA.tif", "NY.png"], url(), boxes(), additional_params(), download_service())) + + @pytest.mark.parametrize("image_name, url, box, params, download_service_name", url_box_pairs) -def test_download_ArcGIS_REST(tmpdir, image_name, url, box, params, download_service_name): +def test_download_arcgis_rest(tmpdir, image_name, url, box, params, download_service_name): async def run_test(): semaphore = asyncio.Semaphore(20) - limiter = AsyncLimiter(1,0.05) + limiter = AsyncLimiter(1, 0.05) xmin, ymin, xmax, ymax = box bbox_crs = "EPSG:4326" # Assuming WGS84 for bounding box CRS savedir = tmpdir - filename = await download.download_web_server(semaphore, limiter, url, xmin, ymin, xmax, ymax, bbox_crs, savedir, additional_params=params, image_name=image_name, download_service=download_service_name) - + filename = await download.download_web_server(semaphore, limiter, url, xmin, ymin, xmax, ymax, bbox_crs, + savedir, additional_params=params, image_name=image_name, + download_service=download_service_name) # Check the saved file assert os.path.exists(filename) - + # Confirm file has CRS with rio.open(filename) as src: if image_name.endswith('.tif'): @@ -58,24 +56,24 @@ async def run_test(): else: assert src.crs is None plt.imshow(cv2.imread(filename)[:, :, ::-1]) - + asyncio.run(run_test()) locations = [ [ 'https://aerial.openstreetmap.org.za/layer/ngi-aerial/{z}/{x}/{y}.jpg', - -33.9249, 18.4241, - -30.0000, 22.0000, - 6, - True, + -33.9249, 18.4241, + -30.0000, 22.0000, + 6, + True, 'dataset3', 'CapeTown.tiff' ], [ 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', - 45.699,127, # From (latitude, longitude) - 30,148.492, # To (latitude, longitude) + 45.699, 127, # From (latitude, longitude) + 30, 148.492, # To (latitude, longitude) 6, # Zoom level True, 'dataset', @@ -83,18 +81,18 @@ async def run_test(): ], ] + # Parametrize test cases with different locations @pytest.mark.parametrize("source, lat0, lon0, lat1, lon1, zoom, save_image, save_dir, image_name", locations) -def test_download_TileMapServer(tmpdir, source, lat0, lon0, lat1, lon1, zoom, save_image, save_dir, image_name): +def test_download_tile_mapserver(tmpdir, source, lat0, lon0, lat1, lon1, zoom, save_image, save_dir, image_name): async def run_test(): semaphore = asyncio.Semaphore(20) limiter = AsyncLimiter(1, 0.05) save_path = os.path.join(tmpdir, image_name) - await download.download_web_server(semaphore, limiter, source, lat0, lon0, lat1, lon1, zoom, save_image=True, save_dir=tmpdir, image_name=image_name) - + await download.download_web_server(semaphore, limiter, source, lat0, lon0, lat1, lon1, zoom, save_image=True, + save_dir=tmpdir, image_name=image_name) # Check if the image file is saved assert os.path.exists(save_path) - # Confirm file format and load image img = cv2.imread(save_path) assert img is not None diff --git a/tests/test_environment.py b/tests/test_environment.py index 1455f27a9..c4ee8f442 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -1,4 +1,4 @@ -#test environment +# test environment def test_environment(): diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 1f1ecbf8f..a8490b87d 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -1,5 +1,5 @@ -#Test evaluate -#Test IoU +# Test evaluate +# Test IoU from .conftest import download_release from deepforest.utilities import read_file from deepforest import evaluate @@ -11,15 +11,19 @@ import numpy as np import geopandas as gpd + def test_evaluate_image(m): csv_file = get_data("OSBS_029.csv") predictions = m.predict_file(csv_file=csv_file, root_dir=os.path.dirname(csv_file)) ground_truth = read_file(csv_file) predictions.label = 0 - result = evaluate.evaluate_image_boxes(predictions=predictions, ground_df=ground_truth, root_dir=os.path.dirname(csv_file)) - + result = evaluate.evaluate_image_boxes(predictions=predictions, + ground_df=ground_truth, + root_dir=os.path.dirname(csv_file)) + assert result.shape[0] == ground_truth.shape[0] - assert sum(result.IoU) > 10 + assert sum(result.IoU) > 10 + def test_evaluate_boxes(m, tmpdir): csv_file = get_data("OSBS_029.csv") @@ -27,62 +31,71 @@ def test_evaluate_boxes(m, tmpdir): predictions.label = "Tree" ground_truth = read_file(csv_file) predictions = predictions.loc[range(10)] - results = evaluate.evaluate_boxes(predictions=predictions, ground_df=ground_truth, root_dir=os.path.dirname(csv_file)) - + results = evaluate.evaluate_boxes(predictions=predictions, + ground_df=ground_truth, + root_dir=os.path.dirname(csv_file)) + assert results["results"].shape[0] == ground_truth.shape[0] assert results["box_recall"] > 0.1 - assert results["class_recall"].shape == (1,4) + assert results["class_recall"].shape == (1, 4) assert results["class_recall"].recall.values == 1 assert results["class_recall"].precision.values == 1 assert "score" in results["results"].columns assert results["results"].true_label.unique() == "Tree" + def test_evaluate_boxes_multiclass(): csv_file = get_data("testfile_multi.csv") ground_truth = read_file(csv_file) ground_truth["label"] = ground_truth.label.astype("category").cat.codes - - #Manipulate the data to create some false positives + + # Manipulate the data to create some false positives predictions = ground_truth.copy() predictions["score"] = 1 predictions.iloc[[36, 35, 34], predictions.columns.get_indexer(['label'])] - results = evaluate.evaluate_boxes(predictions=predictions, ground_df=ground_truth, root_dir=os.path.dirname(csv_file)) - + results = evaluate.evaluate_boxes(predictions=predictions, + ground_df=ground_truth, + root_dir=os.path.dirname(csv_file)) + assert results["results"].shape[0] == ground_truth.shape[0] - assert results["class_recall"].shape == (2,4) - + assert results["class_recall"].shape == (2, 4) + + def test_evaluate_boxes_save_images(tmpdir): csv_file = get_data("testfile_multi.csv") ground_truth = read_file(csv_file) ground_truth["label"] = ground_truth.label.astype("category").cat.codes - - #Manipulate the data to create some false positives + + # Manipulate the data to create some false positives predictions = ground_truth.copy() predictions["score"] = 1 predictions.iloc[[36, 35, 34], predictions.columns.get_indexer(['label'])] - results = evaluate.evaluate_boxes(predictions=predictions, ground_df=ground_truth, root_dir=os.path.dirname(csv_file), savedir=tmpdir) - assert all([os.path.exists("{}/{}".format(tmpdir,x)) for x in ground_truth.image_path]) + results = evaluate.evaluate_boxes(predictions=predictions, + ground_df=ground_truth, + root_dir=os.path.dirname(csv_file), + savedir=tmpdir) + assert all([os.path.exists("{}/{}".format(tmpdir, x)) for x in ground_truth.image_path]) + def test_evaluate_empty(m): m = main.deepforest() m.config["score_thresh"] = 0.8 csv_file = get_data("OSBS_029.csv") root_dir = os.path.dirname(csv_file) - results = m.evaluate(csv_file, root_dir, iou_threshold = 0.4) - - #Does this make reasonable predictions, we know the model works. + results = m.evaluate(csv_file, root_dir, iou_threshold=0.4) + + # Does this make reasonable predictions, we know the model works. assert np.isnan(results["box_precision"]) assert results["box_recall"] == 0 - + + @pytest.fixture def sample_results(): # Create a sample DataFrame for testing - data = { - 'true_label': [1, 1, 2], - 'predicted_label': [1, 2, 1] - } + data = {'true_label': [1, 1, 2], 'predicted_label': [1, 2, 1]} return pd.DataFrame(data) + def test_compute_class_recall(sample_results): # Test case with sample data expected_recall = pd.DataFrame({ @@ -94,7 +107,8 @@ def test_compute_class_recall(sample_results): assert evaluate.compute_class_recall(sample_results).equals(expected_recall) -@pytest.mark.parametrize("root_dir",[None,"tmpdir"]) + +@pytest.mark.parametrize("root_dir", [None, "tmpdir"]) def test_point_recall_image(root_dir, tmpdir): img_path = get_data("OSBS_029.png") if root_dir == "tmpdir": @@ -123,12 +137,13 @@ def test_point_recall_image(root_dir, tmpdir): result = evaluate._point_recall_image_(predictions, ground_df, root_dir=root_dir, savedir=savedir) # check the output, 1 match of 2 ground truth - assert all(result.predicted_label.isnull().values == [False,True]) + assert all(result.predicted_label.isnull().values == [False, True]) assert isinstance(result, gpd.GeoDataFrame) assert "predicted_label" in result.columns assert "true_label" in result.columns assert "geometry" in result.columns + def test_point_recall(): # create sample dataframes predictions = pd.DataFrame({ @@ -148,4 +163,4 @@ def test_point_recall(): results = evaluate.point_recall(ground_df=ground_df, predictions=predictions) assert results["box_recall"] == 0.5 - assert results["class_recall"].recall[0] == 1 \ No newline at end of file + assert results["class_recall"].recall[0] == 1 diff --git a/tests/test_main.py b/tests/test_main.py index c4bebb5bd..6bcba5525 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -28,81 +28,86 @@ @pytest.fixture() def two_class_m(): - m = main.deepforest(config_args={"num_classes":2}, label_dict={"Alive":0,"Dead":1}) - m.config["train"]["csv_file"] = get_data("testfile_multi.csv") + m = main.deepforest(config_args={"num_classes": 2}, + label_dict={ + "Alive": 0, + "Dead": 1 + }) + m.config["train"]["csv_file"] = get_data("testfile_multi.csv") m.config["train"]["root_dir"] = os.path.dirname(get_data("testfile_multi.csv")) m.config["train"]["fast_dev_run"] = True m.config["batch_size"] = 2 - - m.config["validation"]["csv_file"] = get_data("testfile_multi.csv") + + m.config["validation"]["csv_file"] = get_data("testfile_multi.csv") m.config["validation"]["root_dir"] = os.path.dirname(get_data("testfile_multi.csv")) m.config["validation"]["val_accuracy_interval"] = 1 m.create_trainer() - + return m @pytest.fixture() def m(download_release): m = main.deepforest() - m.config["train"]["csv_file"] = get_data("example.csv") + m.config["train"]["csv_file"] = get_data("example.csv") m.config["train"]["root_dir"] = os.path.dirname(get_data("example.csv")) m.config["train"]["fast_dev_run"] = True m.config["batch_size"] = 2 - - m.config["validation"]["csv_file"] = get_data("example.csv") + + m.config["validation"]["csv_file"] = get_data("example.csv") m.config["validation"]["root_dir"] = os.path.dirname(get_data("example.csv")) - m.config["workers"] = 0 + m.config["workers"] = 0 m.config["validation"]["val_accuracy_interval"] = 1 m.config["train"]["epochs"] = 2 - + m.create_trainer() m.use_release(check_release=False) - + return m @pytest.fixture() def m_without_release(): m = main.deepforest() - m.config["train"]["csv_file"] = get_data("example.csv") + m.config["train"]["csv_file"] = get_data("example.csv") m.config["train"]["root_dir"] = os.path.dirname(get_data("example.csv")) m.config["train"]["fast_dev_run"] = True m.config["batch_size"] = 2 - - m.config["validation"]["csv_file"] = get_data("example.csv") + + m.config["validation"]["csv_file"] = get_data("example.csv") m.config["validation"]["root_dir"] = os.path.dirname(get_data("example.csv")) - m.config["workers"] = 0 + m.config["workers"] = 0 m.config["validation"]["val_accuracy_interval"] = 1 m.config["train"]["epochs"] = 2 - + m.create_trainer() return m @pytest.fixture() def raster_path(): - return get_data(path='OSBS_029.tif') + return get_data(path='OSBS_029.tif') def big_file(): tmpdir = tempfile.gettempdir() csv_file = get_data("OSBS_029.csv") image_path = get_data("OSBS_029.png") - df = pd.read_csv(csv_file) - + df = pd.read_csv(csv_file) + big_frame = [] for x in range(3): - img = Image.open("{}/{}".format(os.path.dirname(csv_file), df.image_path.unique()[0])) + img = Image.open("{}/{}".format(os.path.dirname(csv_file), + df.image_path.unique()[0])) cv2.imwrite("{}/{}.png".format(tmpdir, x), np.array(img)) new_df = df.copy() new_df.image_path = "{}.png".format(x) big_frame.append(new_df) - + big_frame = pd.concat(big_frame) - big_frame.to_csv("{}/annotations.csv".format(tmpdir)) - + big_frame.to_csv("{}/annotations.csv".format(tmpdir)) + return "{}/annotations.csv".format(tmpdir) @@ -130,18 +135,25 @@ def test_tensorboard_logger(m, tmpdir): def test_use_bird_release(m): - imgpath = get_data("AWPE Pigeon Lake 2020 DJI_0005.JPG") + imgpath = get_data("AWPE Pigeon Lake 2020 DJI_0005.JPG") m.use_bird_release() boxes = m.predict_image(path=imgpath) assert not boxes.empty def test_train_empty(m, tmpdir): - empty_csv = pd.DataFrame({"image_path":["OSBS_029.png","OSBS_029.tif"],"xmin":[0,10],"xmax":[0,20],"ymin":[0,20],"ymax":[0,30],"label":["Tree","Tree"]}) + empty_csv = pd.DataFrame({ + "image_path": ["OSBS_029.png", "OSBS_029.tif"], + "xmin": [0, 10], + "xmax": [0, 20], + "ymin": [0, 20], + "ymax": [0, 30], + "label": ["Tree", "Tree"] + }) empty_csv.to_csv("{}/empty.csv".format(tmpdir)) m.config["train"]["csv_file"] = "{}/empty.csv".format(tmpdir) m.config["batch_size"] = 2 - m.create_trainer() + m.create_trainer() m.trainer.fit(m) @@ -152,12 +164,12 @@ def test_validation_step(m): m.create_trainer() m.trainer.validate(m) # assert no weights have changed - for p1, p2 in zip(before.named_parameters(), m.named_parameters()): + for p1, p2 in zip(before.named_parameters(), m.named_parameters()): assert p1[1].ne(p2[1]).sum() == 0 # Test train with each architecture -@pytest.mark.parametrize("architecture",["retinanet","FasterRCNN"]) +@pytest.mark.parametrize("architecture", ["retinanet", "FasterRCNN"]) def test_train_single(m_without_release, architecture): m_without_release.config["architecture"] = architecture m_without_release.create_model() @@ -177,76 +189,86 @@ def test_train_multi(two_class_m): def test_train_no_validation(m): - m.config["train"]["fast_dev_run"] = False + m.config["train"]["fast_dev_run"] = False m.config["validation"]["csv_file"] = None - m.config["validation"]["root_dir"] = None + m.config["validation"]["root_dir"] = None m.create_trainer(limit_train_batches=1) m.trainer.fit(m) def test_predict_image_empty(m): - image = np.random.random((400,400,3)).astype("float32") - prediction = m.predict_image(image = image) - + image = np.random.random((400, 400, 3)).astype("float32") + prediction = m.predict_image(image=image) + assert prediction is None def test_predict_image_fromfile(m): path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") - prediction = m.predict_image(path = path) - + prediction = m.predict_image(path=path) + assert isinstance(prediction, pd.DataFrame) - assert set(prediction.columns) == {"xmin","ymin","xmax","ymax","label","score","image_path"} + assert set(prediction.columns) == { + "xmin", "ymin", "xmax", "ymax", "label", "score", "image_path" + } def test_predict_image_fromarray(m): image_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") - + # assert error of dtype with pytest.raises(TypeError): image = Image.open(image_path) - prediction = m.predict_image(image = image) - + prediction = m.predict_image(image=image) + image = np.array(Image.open(image_path).convert("RGB")) with pytest.warns(UserWarning, match="Image type is uint8, transforming to float32"): - prediction = m.predict_image(image = image) - + prediction = m.predict_image(image=image) + assert isinstance(prediction, pd.DataFrame) - assert set(prediction.columns) == {"xmin","ymin","xmax","ymax","label","score"} + assert set(prediction.columns) == {"xmin", "ymin", "xmax", "ymax", "label", "score"} def test_predict_return_plot(m): image = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") image = np.array(Image.open(image)) image = image.astype('float32') - plot = m.predict_image(image = image, return_plot=True) + plot = m.predict_image(image=image, return_plot=True) assert isinstance(plot, np.ndarray) + def test_predict_big_file(m, tmpdir): m.config["train"]["fast_dev_run"] = False - m.create_trainer() + m.create_trainer() csv_file = big_file() original_file = pd.read_csv(csv_file) - df = m.predict_file(csv_file=csv_file, root_dir = os.path.dirname(csv_file), savedir=tmpdir) - assert set(df.columns) == {'label', 'score', 'image_path', 'geometry',"xmin","ymin","xmax","ymax"} - + df = m.predict_file(csv_file=csv_file, + root_dir=os.path.dirname(csv_file), + savedir=tmpdir) + assert set(df.columns) == { + 'label', 'score', 'image_path', 'geometry', "xmin", "ymin", "xmax", "ymax" + } + printed_plots = glob.glob("{}/*.png".format(tmpdir)) assert len(printed_plots) == len(original_file.image_path.unique()) + def test_predict_small_file(m, tmpdir): csv_file = get_data("OSBS_029.csv") original_file = pd.read_csv(csv_file) - df = m.predict_file(csv_file, root_dir = os.path.dirname(csv_file), savedir=tmpdir) - assert set(df.columns) == {'label', 'score', 'image_path', 'geometry',"xmin","ymin","xmax","ymax"} + df = m.predict_file(csv_file, root_dir=os.path.dirname(csv_file), savedir=tmpdir) + assert set(df.columns) == { + 'label', 'score', 'image_path', 'geometry', "xmin", "ymin", "xmax", "ymax" + } printed_plots = glob.glob("{}/*.png".format(tmpdir)) assert len(printed_plots) == len(original_file.image_path.unique()) -@pytest.mark.parametrize("batch_size",[1, 2]) +@pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_dataloader(m, batch_size, raster_path): m.config["batch_size"] = batch_size - tile = np.array(Image.open(raster_path)) - ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) + tile = np.array(Image.open(raster_path)) + ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100) dl = m.predict_dataloader(ds) batch = next(iter(dl)) batch.shape[0] == batch_size @@ -260,9 +282,11 @@ def test_predict_tile(m, raster_path): patch_size=300, patch_overlap=0.1, return_plot=False) - + assert isinstance(prediction, pd.DataFrame) - assert set(prediction.columns) == {"xmin","ymin","xmax","ymax","label","score","image_path"} + assert set(prediction.columns) == { + "xmin", "ymin", "xmax", "ymax", "label", "score", "image_path" + } assert not prediction.empty @@ -279,20 +303,21 @@ def test_predict_tile_from_array(m, patch_overlap, raster_path): assert not prediction.empty -@pytest.mark.parametrize("patch_overlap",[0.1, 0]) +@pytest.mark.parametrize("patch_overlap", [0.1, 0]) def test_predict_tile_from_array_with_return_plot(m, patch_overlap, raster_path): - #test predict numpy image + # test predict numpy image image = np.array(Image.open(raster_path)) m.config["train"]["fast_dev_run"] = False - m.create_trainer() - prediction = m.predict_tile(image = image, - patch_size = 300, - patch_overlap = patch_overlap, - return_plot = True, - color=(0,255,0)) + m.create_trainer() + prediction = m.predict_tile(image=image, + patch_size=300, + patch_overlap=patch_overlap, + return_plot=True, + color=(0, 255, 0)) assert isinstance(prediction, np.ndarray) assert prediction.size > 0 + def test_predict_tile_no_mosaic(m, raster_path): # test no mosaic, return a tuple of crop and prediction m.config["train"]["fast_dev_run"] = False @@ -304,7 +329,7 @@ def test_predict_tile_no_mosaic(m, raster_path): mosaic=False) assert len(prediction) == 4 assert len(prediction[0]) == 2 - assert prediction[0][1].shape == (300,300, 3) + assert prediction[0][1].shape == (300, 300, 3) def test_evaluate(m, tmpdir): @@ -320,12 +345,13 @@ def test_evaluate(m, tmpdir): assert results["results"].predicted_label.dropna().unique()[0] == "Tree" assert results["predictions"].shape[0] > 0 assert results["predictions"].label.dropna().unique()[0] == "Tree" - + df = pd.read_csv(csv_file) assert results["results"].shape[0] == df.shape[0] + def test_train_callbacks(m): - csv_file = get_data("example.csv") + csv_file = get_data("example.csv") root_dir = os.path.dirname(csv_file) train_ds = m.load_dataset(csv_file, root_dir=root_dir) @@ -344,27 +370,29 @@ def on_train_end(self, trainer, pl_module): trainer = Trainer(fast_dev_run=True) trainer.fit(m, train_ds) + def test_custom_config_file_path(ROOT, tmpdir): - m = main.deepforest(config_file='{}/deepforest_config.yml'.format(os.path.dirname(ROOT))) + m = main.deepforest( + config_file='{}/deepforest_config.yml'.format(os.path.dirname(ROOT))) def test_save_and_reload_checkpoint(m, tmpdir): - img_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") + img_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png") m.config["train"]["fast_dev_run"] = True m.create_trainer() # save the prediction dataframe after training and # compare with prediction after reload checkpoint - m.trainer.fit(m) - pred_after_train = m.predict_image(path = img_path) + m.trainer.fit(m) + pred_after_train = m.predict_image(path=img_path) m.save_model("{}/checkpoint.pl".format(tmpdir)) # reload the checkpoint to model object after = main.deepforest.load_from_checkpoint("{}/checkpoint.pl".format(tmpdir)) - pred_after_reload = after.predict_image(path = img_path) + pred_after_reload = after.predict_image(path=img_path) assert not pred_after_train.empty assert not pred_after_reload.empty - pd.testing.assert_frame_equal(pred_after_train,pred_after_reload) + pd.testing.assert_frame_equal(pred_after_train, pred_after_reload) def test_save_and_reload_weights(m, tmpdir): @@ -379,12 +407,13 @@ def test_save_and_reload_weights(m, tmpdir): # reload the checkpoint to model object after = main.deepforest() - after.model.load_state_dict(torch.load("{}/checkpoint.pt".format(tmpdir), weights_only=True)) + after.model.load_state_dict( + torch.load("{}/checkpoint.pt".format(tmpdir), weights_only=True)) pred_after_reload = after.predict_image(path=img_path) assert not pred_after_train.empty assert not pred_after_reload.empty - pd.testing.assert_frame_equal(pred_after_train,pred_after_reload) + pd.testing.assert_frame_equal(pred_after_train, pred_after_reload) def test_reload_multi_class(two_class_m, tmpdir): @@ -393,9 +422,10 @@ def test_reload_multi_class(two_class_m, tmpdir): two_class_m.trainer.fit(two_class_m) two_class_m.save_model("{}/checkpoint.pl".format(tmpdir)) before = two_class_m.trainer.validate(two_class_m) - + # reload - old_model = main.deepforest.load_from_checkpoint("{}/checkpoint.pl".format(tmpdir), weights_only=True) + old_model = main.deepforest.load_from_checkpoint("{}/checkpoint.pl".format(tmpdir), + weights_only=True) old_model.config = two_class_m.config assert old_model.config["num_classes"] == 2 old_model.create_trainer() @@ -405,22 +435,23 @@ def test_reload_multi_class(two_class_m, tmpdir): def test_override_transforms(): + def get_transform(augment): """This is the new transform""" if augment: print("I'm a new augmentation!") - transform = A.Compose([ - A.HorizontalFlip(p=0.5), - ToTensorV2() - ], bbox_params=A.BboxParams(format='pascal_voc',label_fields=["category_ids"])) - + transform = A.Compose( + [A.HorizontalFlip(p=0.5), ToTensorV2()], + bbox_params=A.BboxParams(format='pascal_voc', + label_fields=["category_ids"])) + else: transform = ToTensorV2() return transform m = main.deepforest(transforms=get_transform) - csv_file = get_data("example.csv") + csv_file = get_data("example.csv") root_dir = os.path.dirname(csv_file) train_ds = m.load_dataset(csv_file, root_dir=root_dir, augment=True) @@ -433,9 +464,9 @@ def test_over_score_thresh(m): img = get_data("OSBS_029.png") original_score_thresh = m.model.score_thresh m.model.score_thresh = 0.8 - + # trigger update - boxes = m.predict_image(path = img) + boxes = m.predict_image(path=img) assert all(boxes.score > 0.8) assert m.model.score_thresh == 0.8 @@ -452,13 +483,17 @@ def test_iou_metric(m): def test_config_args(m): assert not m.config["num_classes"] == 2 - m = main.deepforest(config_args={"num_classes":2}, label_dict={"Alive":0,"Dead":1}) + m = main.deepforest(config_args={"num_classes": 2}, + label_dict={ + "Alive": 0, + "Dead": 1 + }) assert m.config["num_classes"] == 2 # These call also be nested for train and val arguments assert not m.config["train"]["epochs"] == 7 - m2 = main.deepforest(config_args={"train":{"epochs":7}}) + m2 = main.deepforest(config_args={"train": {"epochs": 7}}) assert m2.config["train"]["epochs"] == 7 @@ -475,9 +510,12 @@ def existing_loader(m, tmpdir): # Copy the new images to the tmpdir train.image_path.unique() image_path = train.image_path.unique()[0] - shutil.copyfile("{}/{}".format(m.config["train"]["root_dir"], image_path), tmpdir.strpath + "/{}".format(image_path)) - shutil.copyfile("{}/{}".format(m.config["train"]["root_dir"], image_path), tmpdir.strpath + "/{}".format(image_path + "2")) - shutil.copyfile("{}/{}".format(m.config["train"]["root_dir"], image_path), tmpdir.strpath + "/{}".format(image_path + "3")) + shutil.copyfile("{}/{}".format(m.config["train"]["root_dir"], image_path), + tmpdir.strpath + "/{}".format(image_path)) + shutil.copyfile("{}/{}".format(m.config["train"]["root_dir"], image_path), + tmpdir.strpath + "/{}".format(image_path + "2")) + shutil.copyfile("{}/{}".format(m.config["train"]["root_dir"], image_path), + tmpdir.strpath + "/{}".format(image_path + "3")) existing_loader = m.load_dataset(csv_file="{}/train.csv".format(tmpdir.strpath), root_dir=tmpdir.strpath, batch_size=m.config["batch_size"] + 1) @@ -494,7 +532,7 @@ def test_load_existing_train_dataloader(m, tmpdir, existing_loader): m.create_trainer() m.trainer.fit(m) batch = next(iter(m.trainer.train_dataloader)) - assert len(batch[0]) == m.config["batch_size"] + assert len(batch[0]) == m.config["batch_size"] # Existing train dataloader m.config["train"]["csv_file"] = "{}/train.csv".format(tmpdir.strpath) @@ -519,29 +557,30 @@ def test_existing_val_dataloader(m, tmpdir, existing_loader): assert len(batch[0]) == m.config["batch_size"] + 1 -def test_existing_predict_dataloader(m, tmpdir): +def test_existing_predict_dataloader(m, tmpdir): # Predict datasets yield only images - ds = dataset.TileDataset(tile=np.random.random((400,400,3)).astype("float32"), patch_overlap=0.1, patch_size=100) + ds = dataset.TileDataset(tile=np.random.random((400, 400, 3)).astype("float32"), + patch_overlap=0.1, + patch_size=100) existing_loader = m.predict_dataloader(ds) batches = m.trainer.predict(m, existing_loader) len(batches[0]) == m.config["batch_size"] + 1 # Test train with each scheduler -@pytest.mark.parametrize("scheduler,expected",[("cosine","CosineAnnealingLR"), - ("lambdaLR","LambdaLR"), - ("multiplicativeLR","MultiplicativeLR"), - ("stepLR","StepLR"), - ("multistepLR","MultiStepLR"), - ("exponentialLR","ExponentialLR"), - ("reduceLROnPlateau","ReduceLROnPlateau")]) +@pytest.mark.parametrize("scheduler,expected", + [("cosine", "CosineAnnealingLR"), ("lambdaLR", "LambdaLR"), + ("multiplicativeLR", "MultiplicativeLR"), ("stepLR", "StepLR"), + ("multistepLR", "MultiStepLR"), + ("exponentialLR", "ExponentialLR"), + ("reduceLROnPlateau", "ReduceLROnPlateau")]) def test_configure_optimizers(scheduler, expected): scheduler_config = { "type": scheduler, "params": { "T_max": 10, "eta_min": 0.00001, - "lr_lambda": lambda epoch: 0.95 ** epoch, # For lambdaLR and multiplicativeLR + "lr_lambda": lambda epoch: 0.95**epoch, # For lambdaLR and multiplicativeLR "step_size": 30, # For stepLR "gamma": 0.1, # For stepLR, multistepLR, and exponentialLR "milestones": [50, 100], # For multistepLR @@ -558,10 +597,10 @@ def test_configure_optimizers(scheduler, expected): }, "expected": expected } - + annotations_file = get_data("testfile_deepforest.csv") root_dir = os.path.dirname(annotations_file) - + config_args = { "train": { "lr": 0.01, @@ -576,21 +615,24 @@ def test_configure_optimizers(scheduler, expected): "root_dir": root_dir } } - + # Initialize the model with the config arguments m = main.deepforest(config_args=config_args) - + # Create and run the trainer m.create_trainer(limit_train_batches=1.0) m.trainer.fit(m) - + # Assert the scheduler type - assert type(m.trainer.lr_scheduler_configs[0].scheduler).__name__ == scheduler_config["expected"], f"Scheduler type mismatch for {scheduler_config['type']}" + assert type(m.trainer.lr_scheduler_configs[0].scheduler).__name__ == scheduler_config[ + "expected"], f"Scheduler type mismatch for {scheduler_config['type']}" + @pytest.fixture() def crop_model(): return model.CropModel() + def test_predict_tile_with_crop_model(m, config): raster_path = get_data("SOAP_061.png") patch_size = 400 @@ -598,8 +640,6 @@ def test_predict_tile_with_crop_model(m, config): iou_threshold = 0.15 return_plot = False mosaic = True - - # Set up the crop model crop_model = model.CropModel() @@ -607,13 +647,16 @@ def test_predict_tile_with_crop_model(m, config): m.config["train"]["fast_dev_run"] = False m.create_trainer() result = m.predict_tile(raster_path=raster_path, - patch_size=patch_size, - patch_overlap=patch_overlap, - iou_threshold=iou_threshold, - return_plot=return_plot, - mosaic=mosaic, - crop_model=crop_model) + patch_size=patch_size, + patch_overlap=patch_overlap, + iou_threshold=iou_threshold, + return_plot=return_plot, + mosaic=mosaic, + crop_model=crop_model) # Assert the result assert isinstance(result, pd.DataFrame) - assert set(result.columns) == {"xmin", "ymin", "xmax", "ymax", "label", "score", "cropmodel_label","cropmodel_score","image_path"} \ No newline at end of file + assert set(result.columns) == { + "xmin", "ymin", "xmax", "ymax", "label", "score", "cropmodel_label", + "cropmodel_score", "image_path" + } diff --git a/tests/test_model.py b/tests/test_model.py index 52b0fa0c1..426ae0f5e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -8,46 +8,56 @@ from torchvision import transforms import cv2 + # The model object is achitecture agnostic container. def test_model_no_args(config): with pytest.raises(ValueError): model.Model(config) + # The model object is achitecture agnostic container. def test_model_no_args(config): with pytest.raises(ValueError): model.Model(config) + @pytest.fixture() def crop_model(): crop_model = model.CropModel(num_classes=2) return crop_model -def test_crop_model(crop_model): # Use pytest tempdir fixture to create a temporary directory + +def test_crop_model( + crop_model): # Use pytest tempdir fixture to create a temporary directory # Test forward pass x = torch.rand(4, 3, 224, 224) output = crop_model.forward(x) assert output.shape == (4, 2) - + # Test training step batch = (x, torch.tensor([0, 1, 0, 1])) loss = crop_model.training_step(batch, batch_idx=0) assert isinstance(loss, torch.Tensor) - + # Test validation step val_batch = (x, torch.tensor([0, 1, 0, 1])) val_loss = crop_model.validation_step(val_batch, batch_idx=0) assert isinstance(val_loss, torch.Tensor) + def test_crop_model_train(crop_model, tmpdir): df = pd.read_csv(get_data("testfile_multi.csv")) boxes = df[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist() root_dir = os.path.dirname(get_data("SOAP_061.png")) images = df.image_path.values - crop_model.write_crops(boxes=boxes,labels=df.label.values, root_dir=root_dir, images=images, savedir=tmpdir) + crop_model.write_crops(boxes=boxes, + labels=df.label.values, + root_dir=root_dir, + images=images, + savedir=tmpdir) - #Create a trainer + # Create a trainer crop_model.create_trainer(fast_dev_run=True) crop_model.load_from_disk(train_dir=tmpdir, val_dir=tmpdir) @@ -62,10 +72,11 @@ def test_crop_model_train(crop_model, tmpdir): crop_model.trainer.fit(crop_model) crop_model.trainer.validate(crop_model) + def test_crop_model_custom_transform(): # Create a dummy instance of CropModel crop_model = model.CropModel(num_classes=2) - + def custom_transform(self, augment): data_transforms = [] data_transforms.append(transforms.ToTensor()) @@ -81,4 +92,3 @@ def custom_transform(self, augment): crop_model.get_transform = custom_transform output = crop_model.forward(x) assert output.shape == (4, 2) - diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index ebaaf9697..3db438388 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -8,7 +8,8 @@ import shutil import yaml -@pytest.mark.parametrize("num_workers",[0 ,2]) + +@pytest.mark.parametrize("num_workers", [0, 2]) def test_predict_tile_workers(m, num_workers): # Default workers is 0 original_workers = m.config["workers"] @@ -24,6 +25,7 @@ def test_predict_tile_workers(m, num_workers): dataloader = m.predict_dataloader(ds) assert dataloader.num_workers == num_workers + def test_predict_tile_workers_config(tmpdir): # Open config file and change workers to 1, save to tmpdir config_file = get_data("deepforest_config.yml") @@ -34,15 +36,13 @@ def test_predict_tile_workers_config(tmpdir): x["workers"] = 1 with open(tmp_config_file, "w+") as f: f.write(yaml.dump(x)) - + m = main.deepforest(config_file=tmp_config_file) csv_file = get_data("OSBS_029.csv") # make a dataset ds = dataset.TreeDataset(csv_file=csv_file, - root_dir=os.path.dirname(csv_file), - transforms=None, - train=False) + root_dir=os.path.dirname(csv_file), + transforms=None, + train=False) dataloader = m.predict_dataloader(ds) assert dataloader.num_workers == 1 - - diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 442ade724..52ae1e5ea 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -16,6 +16,7 @@ from shapely import geometry from shapely import geometry + @pytest.fixture() def config(): config = utilities.read_config("deepforest_config.yml") @@ -31,12 +32,14 @@ def config(): return config + @pytest.fixture() def geodataframe(): csv_file = get_data("OSBS_029.csv") annotations = utilities.read_file(csv_file) return annotations + @pytest.fixture() def image(config): raster = Image.open(config["path_to_raster"]) @@ -55,7 +58,7 @@ def test_select_annotations(config, image): image_annotations = utilities.read_file(csv_file) selected_annotations = preprocess.select_annotations(image_annotations, - window = windows[0]) + window=windows[0]) # The largest box cannot be off the edge of the window assert selected_annotations.geometry.bounds.minx.min() >= 0 @@ -63,6 +66,7 @@ def test_select_annotations(config, image): assert selected_annotations.geometry.bounds.maxx.max() <= 300 assert selected_annotations.geometry.bounds.maxy.max() <= 300 + @pytest.mark.parametrize("input_type", ["path", "dataframe"]) def test_split_raster(config, tmpdir, input_type, geodataframe): """Split raster into crops with overlaps to maintain all annotations""" @@ -70,7 +74,7 @@ def test_split_raster(config, tmpdir, input_type, geodataframe): annotations = utilities.read_pascal_voc( get_data("2019_YELL_2_528000_4978000_image_crop2.xml")) annotations.to_csv("{}/example.csv".format(tmpdir), index=False) - #annotations.label = 0 + # annotations.label = 0 if input_type == "path": annotations_file = get_data("OSBS_029.csv") @@ -86,7 +90,8 @@ def test_split_raster(config, tmpdir, input_type, geodataframe): # Returns a 7 column pandas array assert output_annotations.shape[1] == 7 assert not output_annotations.empty - + + def test_split_raster_no_annotations(config, tmpdir): """Split raster into crops with overlaps to maintain all annotations""" raster = get_data("2019_YELL_2_528000_4978000_image_crop2.png") @@ -103,24 +108,23 @@ def test_split_raster_no_annotations(config, tmpdir): # Assert that all output_crops exist for crop in output_crops: assert os.path.exists(crop) - + def test_split_raster_from_image(config, tmpdir, geodataframe): r = rasterio.open(config["path_to_raster"]).read() r = np.rollaxis(r, 0, 3) - annotations_file = preprocess.split_raster( - numpy_image=r, - annotations_file=geodataframe, - save_dir=tmpdir, - patch_size=config["patch_size"], - patch_overlap=config["patch_overlap"], - image_name="OSBS_029.tif") + annotations_file = preprocess.split_raster(numpy_image=r, + annotations_file=geodataframe, + save_dir=tmpdir, + patch_size=config["patch_size"], + patch_overlap=config["patch_overlap"], + image_name="OSBS_029.tif") assert not annotations_file.empty + @pytest.mark.parametrize("allow_empty", [True, False]) def test_split_raster_empty(tmpdir, config, allow_empty): - # Blank annotations file blank_annotations = pd.DataFrame({ "image_path": "OSBS_029.tif", @@ -143,15 +147,16 @@ def test_split_raster_empty(tmpdir, config, allow_empty): patch_overlap=config["patch_overlap"], allow_empty=allow_empty) else: - annotations_file = preprocess.split_raster( - path_to_raster=config["path_to_raster"], - annotations_file=tmpdir.join("blank_annotations.csv").strpath, - save_dir=tmpdir, - patch_size=config["patch_size"], - patch_overlap=config["patch_overlap"], - allow_empty=allow_empty) - assert annotations_file.shape[0] == 4 - assert tmpdir.join("OSBS_029_1.png").exists() + annotations_file = preprocess.split_raster( + path_to_raster=config["path_to_raster"], + annotations_file=tmpdir.join("blank_annotations.csv").strpath, + save_dir=tmpdir, + patch_size=config["patch_size"], + patch_overlap=config["patch_overlap"], + allow_empty=allow_empty) + assert annotations_file.shape[0] == 4 + assert tmpdir.join("OSBS_029_1.png").exists() + def test_split_size_error(config, tmpdir, geodataframe): with pytest.raises(ValueError): @@ -182,6 +187,7 @@ def test_split_raster_4_band_warns(config, tmpdir, orders, geodataframe): patch_overlap=config["patch_overlap"], image_name="OSBS_029.tif") + # Test split_raster with point annotations file def test_split_raster_with_point_annotations(tmpdir, config): # Create a temporary point annotations file @@ -195,11 +201,14 @@ def test_split_raster_with_point_annotations(tmpdir, config): annotations.to_csv(annotations_file, index=False) # Call split_raster function - preprocess.split_raster(annotations_file=annotations_file.strpath, path_to_raster=config["path_to_raster"], save_dir=tmpdir) + preprocess.split_raster(annotations_file=annotations_file.strpath, + path_to_raster=config["path_to_raster"], + save_dir=tmpdir) # Assert that the output annotations file is created assert tmpdir.join("OSBS_029_0.png").exists() + # Test split_raster with box annotations file def test_split_raster_with_box_annotations(tmpdir, config): # Create a temporary box annotations file @@ -215,15 +224,21 @@ def test_split_raster_with_box_annotations(tmpdir, config): annotations.to_csv(annotations_file, index=False) # Call split_raster function - preprocess.split_raster(annotations_file=annotations_file.strpath, path_to_raster=config["path_to_raster"], save_dir=tmpdir) + preprocess.split_raster(annotations_file=annotations_file.strpath, + path_to_raster=config["path_to_raster"], + save_dir=tmpdir) # Assert that the output annotations file is created assert tmpdir.join("OSBS_029_0.png").exists() + # Test split_raster with polygon annotations file def test_split_raster_with_polygon_annotations(tmpdir, config): # Create a temporary polygon annotations file with a polygon in WKT format - sample_geometry = [geometry.Polygon([(0, 0), (0, 2), (1, 1), (1, 0), (0, 0)]), geometry.Polygon([(2, 2), (2, 4), (3, 3), (3, 2), (2, 2)])] + sample_geometry = [ + geometry.Polygon([(0, 0), (0, 2), (1, 1), (1, 0), (0, 0)]), + geometry.Polygon([(2, 2), (2, 4), (3, 3), (3, 2), (2, 2)]) + ] annotations = pd.DataFrame({ "image_path": ["OSBS_029.tif", "OSBS_029.tif"], "polygon": [sample_geometry[0].wkt, sample_geometry[1].wkt], @@ -233,56 +248,68 @@ def test_split_raster_with_polygon_annotations(tmpdir, config): annotations.to_csv(annotations_file, index=False) # Call split_raster function - split_annotations = preprocess.split_raster(annotations_file=annotations_file.strpath, path_to_raster=config["path_to_raster"], save_dir=tmpdir) + split_annotations = preprocess.split_raster(annotations_file=annotations_file.strpath, + path_to_raster=config["path_to_raster"], + save_dir=tmpdir) assert not split_annotations.empty - + # Assert that the output annotations file is created assert tmpdir.join("OSBS_029_0.png").exists() + def test_split_raster_from_csv(tmpdir): - """Read in annotations, convert to a projected shapefile, read back in and crop, the output annotations should still be mantained in logical place""" + """Read in annotations, convert to a projected shapefile, read back in and crop, + the output annotations should still be mantained in logical place""" annotations = get_data("2018_SJER_3_252000_4107000_image_477.csv") path_to_raster = get_data("2018_SJER_3_252000_4107000_image_477.tif") # Check original data - split_annotations = preprocess.split_raster( - annotations_file=annotations, - path_to_raster=path_to_raster, - save_dir=tmpdir, - root_dir=os.path.dirname(path_to_raster), - patch_size=300) + split_annotations = preprocess.split_raster(annotations_file=annotations, + path_to_raster=path_to_raster, + save_dir=tmpdir, + root_dir=os.path.dirname(path_to_raster), + patch_size=300) assert not split_annotations.empty - + # Plot labels - images = visualize.plot_prediction_dataframe(split_annotations, root_dir=tmpdir, savedir=tmpdir) + images = visualize.plot_prediction_dataframe(split_annotations, + root_dir=tmpdir, + savedir=tmpdir) for image in images: im = Image.open(image) im.show() + def test_split_raster_from_shp(tmpdir): annotations = get_data("2018_SJER_3_252000_4107000_image_477.csv") path_to_raster = get_data("2018_SJER_3_252000_4107000_image_477.tif") - gdf = utilities.read_file(annotations) - geo_coords = utilities.image_to_geo_coordinates(gdf, root_dir=os.path.dirname(path_to_raster)) + gdf = utilities.read_file(annotations) + geo_coords = utilities.image_to_geo_coordinates( + gdf, root_dir=os.path.dirname(path_to_raster)) annotations_file = tmpdir.join("projected_annotations.shp").strpath geo_coords.to_file(annotations_file) # Call split_raster function - split_annotations = preprocess.split_raster( - annotations_file=annotations_file, - path_to_raster=path_to_raster, save_dir=tmpdir, root_dir=os.path.dirname(path_to_raster), patch_size=300) - + split_annotations = preprocess.split_raster(annotations_file=annotations_file, + path_to_raster=path_to_raster, + save_dir=tmpdir, + root_dir=os.path.dirname(path_to_raster), + patch_size=300) + assert not split_annotations.empty # Plot labels - images = visualize.plot_prediction_dataframe(split_annotations, root_dir=tmpdir, savedir=tmpdir) + images = visualize.plot_prediction_dataframe(split_annotations, + root_dir=tmpdir, + savedir=tmpdir) for image in images: im = Image.open(image) im.show() + # def test_view_annotation_split(tmpdir, config): # """Test that the split annotations can be visualized and mantain location, turn show to True for debugging interactively""" # annotations = get_data("2019_YELL_2_541000_4977000_image_crop.xml") @@ -294,4 +321,3 @@ def test_split_raster_from_shp(tmpdir): # for image in images: # im = Image.open(image) # im.show() - diff --git a/tests/test_retinanet.py b/tests/test_retinanet.py index ecee00753..fda38bd7d 100644 --- a/tests/test_retinanet.py +++ b/tests/test_retinanet.py @@ -1,4 +1,4 @@ -#test retinanet +# test retinanet from deepforest.models import retinanet from deepforest import get_data import pytest @@ -9,42 +9,49 @@ from torchvision.models import resnet50, ResNet50_Weights from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' -os.environ['KMP_DUPLICATE_LIB_OK']='True' -#Empty tester from https://github.com/datumbox/vision/blob/06ebee1a9f10c76d8ac5768fd578362dd5ace6e9/test/test_models_detection_negative_samples.py#L14 +# Empty tester from https://github.com/datumbox/vision/blob/06ebee1a9f10c76d8ac5768fd578362dd5ace6e9/test/test_models_detection_negative_samples.py#L14 def _make_empty_sample(): images = [torch.rand((3, 100, 100), dtype=torch.float32)] boxes = torch.zeros((0, 4), dtype=torch.float32) - negative_target = {"boxes": boxes, - "labels": torch.zeros(0, dtype=torch.int64), - "image_id": 4, - "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), - "iscrowd": torch.zeros((0,), dtype=torch.int64)} + negative_target = { + "boxes": boxes, + "labels": torch.zeros(0, dtype=torch.int64), + "image_id": 4, + "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), + "iscrowd": torch.zeros((0,), dtype=torch.int64) + } targets = [negative_target] return images, targets + def test_retinanet(config): r = retinanet.Model(config) assert r + def test_load_backbone(config): r = retinanet.Model(config) resnet_backbone = r.load_backbone() resnet_backbone.eval() x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - prediction = resnet_backbone(x) + prediction = resnet_backbone(x) + -# This test still fails, do we want a way to pass kwargs directly to method, instead of being limited by config structure? +# This test still fails, do we want a way to pass kwargs directly to method, +# instead of being limited by config structure? # Need to create issue when I get online. -@pytest.mark.parametrize("num_classes",[1,2,10]) +@pytest.mark.parametrize("num_classes", [1, 2, 10]) def test_create_model(config, num_classes): config["num_classes"] = num_classes retinanet_model = retinanet.Model(config).create_model() retinanet_model.eval() x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - predictions = retinanet_model(x) + predictions = retinanet_model(x) + def test_forward_empty(config): r = retinanet.Model(config) @@ -53,6 +60,7 @@ def test_forward_empty(config): loss = model(image, targets) assert torch.equal(loss["bbox_regression"], torch.tensor(0.)) + # Can we update parameters after training def test_mantain_parameters(config): config["retinanet"]["score_thresh"] = 0.4 @@ -60,11 +68,11 @@ def test_mantain_parameters(config): assert retinanet_model.score_thresh == config["retinanet"]["score_thresh"] retinanet_model.eval() x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - predictions = retinanet_model(x) + predictions = retinanet_model(x) assert retinanet_model.score_thresh == config["retinanet"]["score_thresh"] - + retinanet_model.score_thresh = 0.9 retinanet_model.eval() x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - predictions = retinanet_model(x) - assert retinanet_model.score_thresh == 0.9 \ No newline at end of file + predictions = retinanet_model(x) + assert retinanet_model.score_thresh == 0.9 diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 394a8b96f..5a97e11c2 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -13,33 +13,36 @@ from deepforest import utilities, visualize from deepforest import main - -#import general model fixture +# import general model fixture from .conftest import download_release import pytest from PIL import Image + @pytest.fixture() def config(): config = utilities.read_config("deepforest_config.yml") return config + def test_read_pascal_voc(): - annotations = utilities.read_pascal_voc( - xml_path=get_data("OSBS_029.xml")) + annotations = utilities.read_pascal_voc(xml_path=get_data("OSBS_029.xml")) print(annotations.shape) assert annotations.shape[0] == 61 + def test_use_release(download_release): # Download latest model from github release release_tag, state_dict = utilities.use_release(check_release=False) + def test_use_bird_release(download_release): # Download latest model from github release release_tag, state_dict = utilities.use_bird_release() - assert os.path.exists(get_data("bird.pt")) - + assert os.path.exists(get_data("bird.pt")) + + def test_float_warning(config): """Users should get a rounding warning when adding annotations with floats""" float_annotations = "tests/data/float_annotations.txt" @@ -47,28 +50,31 @@ def test_float_warning(config): annotations = utilities.read_pascal_voc(float_annotations) assert annotations.xmin.dtype is np.dtype('int64') - + def test_read_file(tmpdir): - sample_geometry = [geometry.Point(404211.9 + 10,3285102 + 20),geometry.Point(404211.9 + 20,3285102 + 20)] - labels = ["Tree","Tree"] - df = pd.DataFrame({"geometry":sample_geometry,"label":labels}) + sample_geometry = [geometry.Point(404211.9 + 10, 3285102 + 20), geometry.Point(404211.9 + 20, 3285102 + 20)] + labels = ["Tree", "Tree"] + df = pd.DataFrame({"geometry": sample_geometry, "label": labels}) gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:32617") - gdf["geometry"] = [geometry.box(left, bottom, right, top) for left, bottom, right, top in gdf.geometry.buffer(0.5).bounds.values] + gdf["geometry"] = [geometry.box(left, bottom, right, top) for left, bottom, right, top in + gdf.geometry.buffer(0.5).bounds.values] gdf["image_path"] = get_data("OSBS_029.tif") gdf.to_file("{}/annotations.shp".format(tmpdir)) shp = utilities.read_file(input="{}/annotations.shp".format(tmpdir)) assert shp.shape[0] == 2 + def test_shapefile_to_annotations_convert_unprojected_to_boxes(tmpdir): - sample_geometry = [geometry.Point(10,20),geometry.Point(20,40)] - labels = ["Tree","Tree"] - df = pd.DataFrame({"geometry":sample_geometry,"label":labels}) + sample_geometry = [geometry.Point(10, 20), geometry.Point(20, 40)] + labels = ["Tree", "Tree"] + df = pd.DataFrame({"geometry": sample_geometry, "label": labels}) gdf = gpd.GeoDataFrame(df, geometry="geometry") gdf.to_file("{}/annotations.shp".format(tmpdir)) image_path = get_data("OSBS_029.png") shp = utilities.shapefile_to_annotations(shapefile="{}/annotations.shp".format(tmpdir), rgb=image_path) assert shp.shape[0] == 2 + def test_shapefile_to_annotations_invalid_epsg(tmpdir): sample_geometry = [geometry.Point(404211.9 + 10, 3285102 + 20), geometry.Point(404211.9 + 20, 3285102 + 20)] labels = ["Tree", "Tree"] @@ -79,13 +85,15 @@ def test_shapefile_to_annotations_invalid_epsg(tmpdir): image_path = get_data("OSBS_029.tif") with pytest.raises(ValueError): shp = utilities.shapefile_to_annotations(shapefile="{}/annotations.shp".format(tmpdir), rgb=image_path) - + + def test_read_file_boxes_projected(tmpdir): - sample_geometry = [geometry.Point(404211.9 + 10,3285102 + 20),geometry.Point(404211.9 + 20,3285102 + 20)] - labels = ["Tree","Tree"] - df = pd.DataFrame({"geometry":sample_geometry,"label":labels}) + sample_geometry = [geometry.Point(404211.9 + 10, 3285102 + 20), geometry.Point(404211.9 + 20, 3285102 + 20)] + labels = ["Tree", "Tree"] + df = pd.DataFrame({"geometry": sample_geometry, "label": labels}) gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:32617") - gdf["geometry"] = [geometry.box(left, bottom, right, top) for left, bottom, right, top in gdf.geometry.buffer(0.5).bounds.values] + gdf["geometry"] = [geometry.box(left, bottom, right, top) for left, bottom, right, top in + gdf.geometry.buffer(0.5).bounds.values] image_path = get_data("OSBS_029.tif") gdf["image_path"] = image_path gdf.to_file("{}/test_read_file_boxes_projected.shp".format(tmpdir)) @@ -94,20 +102,23 @@ def test_read_file_boxes_projected(tmpdir): shp = utilities.read_file(input="{}/test_read_file_boxes_projected.shp".format(tmpdir)) assert shp.shape[0] == 2 + def test_read_file_points_csv(tmpdir): x = [10, 20] y = [20, 20] labels = ["Tree", "Tree"] - image_path = [get_data("OSBS_029.tif"),get_data("OSBS_029.tif")] + image_path = [get_data("OSBS_029.tif"), get_data("OSBS_029.tif")] df = pd.DataFrame({"x": x, "y": y, "label": labels}) df.to_csv("{}/test_read_file_points.csv".format(tmpdir), index=False) read_df = utilities.read_file(input="{}/test_read_file_points.csv".format(tmpdir)) assert read_df.shape[0] == 2 + def test_read_file_polygons_csv(tmpdir): # Create a sample GeoDataFrame with polygon geometries with 6 points - sample_geometry = [geometry.Polygon([(0, 0), (0, 2), (1, 1), (1, 0), (0, 0)]), geometry.Polygon([(2, 2), (2, 4), (3, 3), (3, 2), (2, 2)])] - + sample_geometry = [geometry.Polygon([(0, 0), (0, 2), (1, 1), (1, 0), (0, 0)]), + geometry.Polygon([(2, 2), (2, 4), (3, 3), (3, 2), (2, 2)])] + labels = ["Tree", "Tree"] image_path = get_data("OSBS_029.png") df = pd.DataFrame({"geometry": sample_geometry, "label": labels, "image_path": os.path.basename(image_path)}) @@ -122,19 +133,21 @@ def test_read_file_polygons_csv(tmpdir): def test_read_file_polygons_projected(tmpdir): - sample_geometry = [geometry.Point(404211.9 + 10,3285102 + 20),geometry.Point(404211.9 + 20,3285102 + 20)] + sample_geometry = [geometry.Point(404211.9 + 10, 3285102 + 20), geometry.Point(404211.9 + 20, 3285102 + 20)] labels = ["Tree", "Tree"] df = pd.DataFrame({"geometry": sample_geometry, "label": labels}) gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:32617") - gdf["geometry"] = [geometry.Polygon([(left, bottom), (left, top), (right, top), (right, bottom)]) for left, bottom, right, top in gdf.geometry.buffer(0.5).bounds.values] + gdf["geometry"] = [geometry.Polygon([(left, bottom), (left, top), (right, top), (right, bottom)]) for + left, bottom, right, top in gdf.geometry.buffer(0.5).bounds.values] image_path = get_data("OSBS_029.tif") gdf["image_path"] = image_path gdf.to_file("{}/test_read_file_polygons_projected.shp".format(tmpdir)) shp = utilities.read_file(input="{}/test_read_file_polygons_projected.shp".format(tmpdir)) assert shp.shape[0] == 2 + def test_read_file_points_projected(tmpdir): - sample_geometry = [geometry.Point(404211.9 + 10,3285102 + 20),geometry.Point(404211.9 + 20,3285102 + 20)] + sample_geometry = [geometry.Point(404211.9 + 10, 3285102 + 20), geometry.Point(404211.9 + 20, 3285102 + 20)] labels = ["Tree", "Tree"] df = pd.DataFrame({"geometry": sample_geometry, "label": labels}) gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:32617") @@ -145,6 +158,7 @@ def test_read_file_points_projected(tmpdir): assert shp.shape[0] == 2 assert shp.geometry.iloc[0].type == "Point" + def test_read_file_boxes_unprojected(tmpdir): # Create a sample GeoDataFrame with box geometries sample_geometry = [geometry.box(0, 0, 1, 1), geometry.box(2, 2, 3, 3)] @@ -160,6 +174,7 @@ def test_read_file_boxes_unprojected(tmpdir): assert annotations.shape[0] == 2 assert annotations.geometry.iloc[0].type == "Polygon" + def test_read_file_points_unprojected(tmpdir): # Create a sample GeoDataFrame with point geometries sample_geometry = [geometry.Point(0.5, 0.5), geometry.Point(2.5, 2.5)] @@ -176,10 +191,12 @@ def test_read_file_points_unprojected(tmpdir): assert annotations.shape[0] == 2 assert annotations.geometry.iloc[0].type == "Point" + def test_read_file_polygons_unprojected(tmpdir): # Create a sample GeoDataFrame with polygon geometries with 6 points - sample_geometry = [geometry.Polygon([(0, 0), (0, 2), (1, 1), (1, 0), (0, 0)]), geometry.Polygon([(2, 2), (2, 4), (3, 3), (3, 2), (2, 2)])] - + sample_geometry = [geometry.Polygon([(0, 0), (0, 2), (1, 1), (1, 0), (0, 0)]), + geometry.Polygon([(2, 2), (2, 4), (3, 3), (3, 2), (2, 2)])] + labels = ["Tree", "Tree"] df = pd.DataFrame({"geometry": sample_geometry, "label": labels}) gdf = gpd.GeoDataFrame(df, geometry="geometry") @@ -194,10 +211,11 @@ def test_read_file_polygons_unprojected(tmpdir): assert annotations.shape[0] == 2 assert annotations.geometry.iloc[0].type == "Polygon" + def test_crop_raster_valid_crop(tmpdir): rgb_path = get_data("2018_SJER_3_252000_4107000_image_477.tif") raster_bounds = rio.open(rgb_path).bounds - + # Define the bounds for cropping bounds = (raster_bounds[0] + 10, raster_bounds[1] + 10, raster_bounds[0] + 30, raster_bounds[1] + 30) @@ -216,10 +234,11 @@ def test_crop_raster_valid_crop(tmpdir): assert src.count == 3 assert src.dtypes == ("uint8", "uint8", "uint8") + def test_crop_raster_invalid_crop(tmpdir): rgb_path = get_data("2018_SJER_3_252000_4107000_image_477.tif") raster_bounds = rio.open(rgb_path).bounds - + # Define the bounds for cropping bounds = (raster_bounds[0] - 100, raster_bounds[1] - 100, raster_bounds[0] - 30, raster_bounds[1] - 30) @@ -227,12 +246,14 @@ def test_crop_raster_invalid_crop(tmpdir): with pytest.raises(ValueError): result = utilities.crop_raster(bounds, rgb_path=rgb_path, savedir=tmpdir, filename="crop") + def test_crop_raster_no_savedir(tmpdir): rgb_path = get_data("2018_SJER_3_252000_4107000_image_477.tif") raster_bounds = rio.open(rgb_path).bounds - + # Define the bounds for cropping - bounds = (int(raster_bounds[0] + 10), int(raster_bounds[1] + 10), int(raster_bounds[0] + 20), int(raster_bounds[1] + 20)) + bounds = (int(raster_bounds[0] + 10), int(raster_bounds[1] + 10), + int(raster_bounds[0] + 20), int(raster_bounds[1] + 20)) # Call the function under test result = utilities.crop_raster(bounds, rgb_path=rgb_path) @@ -240,6 +261,7 @@ def test_crop_raster_no_savedir(tmpdir): # Assert out is a output numpy array assert isinstance(result, np.ndarray) + def test_crop_raster_png_unprojected(tmpdir): # Define the bounds for cropping bounds = (0, 0, 100, 100) @@ -264,6 +286,7 @@ def test_crop_raster_png_unprojected(tmpdir): # Assert the crs is not present assert src.crs is None + def test_geo_to_image_coordinates_UTM_N(tmpdir): """Read in a csv file, make a projected shapefile, convert to image coordinates and view the results""" annotations = get_data("2018_SJER_3_252000_4107000_image_477.csv") @@ -281,61 +304,69 @@ def test_geo_to_image_coordinates_UTM_N(tmpdir): # geo_coords.plot(ax=ax, color="red") # plt.show() - assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] + assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] # Convert to image coordinates - image_coords = utilities.geo_to_image_coordinates(geo_coords, image_bounds=src.bounds, image_resolution=src.res[0]) + image_coords = utilities.geo_to_image_coordinates(geo_coords, image_bounds=src.bounds, image_resolution=src.res[0]) assert image_coords.crs is None - #Confirm overlap + # Confirm overlap numpy_image = src.read() channels, height, width = numpy_image.shape numpy_window = geometry.box(0, 0, width, height) - assert image_coords[image_coords.intersects(numpy_window)].shape[0] == pd.read_csv(annotations).shape[0] + assert image_coords[image_coords.intersects(numpy_window)].shape[0] == pd.read_csv( + annotations).shape[0] - images = visualize.plot_prediction_dataframe(image_coords, root_dir=os.path.dirname(path_to_raster), savedir=tmpdir) + images = visualize.plot_prediction_dataframe(image_coords, + root_dir=os.path.dirname(path_to_raster), + savedir=tmpdir) # Confirm the image coordinates are correct for image in images: im = Image.open(image) im.show() + def test_geo_to_image_coordinates_UTM_S(tmpdir): """Read in a csv file, make a projected shapefile, convert to image coordinates and view the results""" annotations = get_data("australia.shp") path_to_raster = get_data("australia.tif") src = rio.open(path_to_raster) - + geo_coords = gpd.read_file(annotations) src_window = geometry.box(*src.bounds) - #fig, ax = plt.subplots(figsize=(10, 10)) - #gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) - #geo_coords.plot(ax=ax, color="red") - #plt.show() - - assert geo_coords[geo_coords.intersects(src_window)].shape[0] == gpd.read_file(annotations).shape[0] + # fig, ax = plt.subplots(figsize=(10, 10)) + # gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) + # geo_coords.plot(ax=ax, color="red") + # plt.show() + + assert geo_coords[geo_coords.intersects(src_window)].shape[0] == gpd.read_file( + annotations).shape[0] # Convert to image coordinates - image_coords = utilities.geo_to_image_coordinates(geo_coords, image_bounds=src.bounds, image_resolution=src.res[0]) + image_coords = utilities.geo_to_image_coordinates(geo_coords, image_bounds=src.bounds, image_resolution=src.res[0]) assert image_coords.crs is None - #Confirm overlap + # Confirm overlap numpy_image = src.read() channels, height, width = numpy_image.shape numpy_window = geometry.box(0, 0, width, height) - assert image_coords[image_coords.intersects(numpy_window)].shape[0] == gpd.read_file(annotations).shape[0] + assert image_coords[image_coords.intersects(numpy_window)].shape[0] == gpd.read_file(annotations).shape[0] - images = visualize.plot_prediction_dataframe(image_coords, root_dir=os.path.dirname(path_to_raster), savedir=tmpdir) + images = visualize.plot_prediction_dataframe(image_coords, + root_dir=os.path.dirname(path_to_raster), + savedir=tmpdir) # Confirm the image coordinates are correct for image in images: im = Image.open(image) im.show() + def test_image_to_geo_coordinates(tmpdir): annotations = get_data("2018_SJER_3_252000_4107000_image_477.csv") path_to_raster = get_data("2018_SJER_3_252000_4107000_image_477.tif") # Convert to image coordinates - gdf = utilities.read_file(annotations) + gdf = utilities.read_file(annotations) images = visualize.plot_prediction_dataframe(gdf, root_dir=os.path.dirname(path_to_raster), savedir=tmpdir) # Confirm it has no crs @@ -345,27 +376,30 @@ def test_image_to_geo_coordinates(tmpdir): for image in images: im = Image.open(image) im.show(title="before") - + # Convert to geo coordinates src = rio.open(path_to_raster) geo_coords = utilities.image_to_geo_coordinates(gdf, root_dir=os.path.dirname(path_to_raster)) src_window = geometry.box(*src.bounds) - assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] + assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] # Plot using geopandas - #fig, ax = plt.subplots(figsize=(10, 10)) - #gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) - #geo_coords.plot(ax=ax, color="red", alpha=0.2) - #show(src, ax=ax) - #plt.show() + # fig, ax = plt.subplots(figsize=(10, 10)) + # gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) + # geo_coords.plot(ax=ax, color="red", alpha=0.2) + # show(src, ax=ax) + # plt.show() + def test_image_to_geo_coordinates_boxes(tmpdir): annotations = get_data("2018_SJER_3_252000_4107000_image_477.csv") path_to_raster = get_data("2018_SJER_3_252000_4107000_image_477.tif") # Convert to image coordinates - gdf = utilities.read_file(annotations) - images = visualize.plot_prediction_dataframe(gdf, root_dir=os.path.dirname(path_to_raster), savedir=tmpdir) + gdf = utilities.read_file(annotations) + images = visualize.plot_prediction_dataframe(gdf, + root_dir=os.path.dirname(path_to_raster), + savedir=tmpdir) # Confirm it has no crs assert gdf.crs is None @@ -374,28 +408,31 @@ def test_image_to_geo_coordinates_boxes(tmpdir): for image in images: im = Image.open(image) im.show(title="before") - + # Convert to geo coordinates src = rio.open(path_to_raster) geo_coords = utilities.image_to_geo_coordinates(gdf, root_dir=os.path.dirname(path_to_raster)) src_window = geometry.box(*src.bounds) - assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] + assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] # Plot using geopandas - #fig, ax = plt.subplots(figsize=(10, 10)) - #gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) - #geo_coords.plot(ax=ax, color="red", alpha=0.2) - #show(src, ax=ax) - #plt.show() + # fig, ax = plt.subplots(figsize=(10, 10)) + # gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) + # geo_coords.plot(ax=ax, color="red", alpha=0.2) + # show(src, ax=ax) + # plt.show() + def test_image_to_geo_coordinates_points(tmpdir): annotations = get_data("2018_SJER_3_252000_4107000_image_477.csv") path_to_raster = get_data("2018_SJER_3_252000_4107000_image_477.tif") # Convert to image coordinates - gdf = utilities.read_file(annotations) + gdf = utilities.read_file(annotations) gdf["geometry"] = gdf.geometry.centroid - images = visualize.plot_prediction_dataframe(gdf, root_dir=os.path.dirname(path_to_raster), savedir=tmpdir) + images = visualize.plot_prediction_dataframe(gdf, + root_dir=os.path.dirname(path_to_raster), + savedir=tmpdir) # Confirm it has no crs assert gdf.crs is None @@ -404,29 +441,32 @@ def test_image_to_geo_coordinates_points(tmpdir): for image in images: im = Image.open(image) im.show(title="before") - + # Convert to geo coordinates src = rio.open(path_to_raster) geo_coords = utilities.image_to_geo_coordinates(gdf, root_dir=os.path.dirname(path_to_raster)) src_window = geometry.box(*src.bounds) - assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] + assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] # Plot using geopandas - #fig, ax = plt.subplots(figsize=(10, 10)) - #gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) - #geo_coords.plot(ax=ax, color="red", alpha=0.2) - #show(src, ax=ax) - #plt.show() + # fig, ax = plt.subplots(figsize=(10, 10)) + # gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) + # geo_coords.plot(ax=ax, color="red", alpha=0.2) + # show(src, ax=ax) + # plt.show() + def test_image_to_geo_coordinates_polygons(tmpdir): annotations = get_data("2018_SJER_3_252000_4107000_image_477.csv") path_to_raster = get_data("2018_SJER_3_252000_4107000_image_477.tif") # Convert to image coordinates - gdf = utilities.read_file(annotations) + gdf = utilities.read_file(annotations) # Skew boxes to make them polygons - gdf["geometry"] = gdf.geometry.skew(7,7) - images = visualize.plot_prediction_dataframe(gdf, root_dir=os.path.dirname(path_to_raster), savedir=tmpdir) + gdf["geometry"] = gdf.geometry.skew(7, 7) + images = visualize.plot_prediction_dataframe(gdf, + root_dir=os.path.dirname(path_to_raster), + savedir=tmpdir) # Confirm it has no crs assert gdf.crs is None @@ -435,19 +475,19 @@ def test_image_to_geo_coordinates_polygons(tmpdir): for image in images: im = Image.open(image) im.show(title="before") - + # Convert to geo coordinates src = rio.open(path_to_raster) geo_coords = utilities.image_to_geo_coordinates(gdf, root_dir=os.path.dirname(path_to_raster)) src_window = geometry.box(*src.bounds) - assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] + assert geo_coords[geo_coords.intersects(src_window)].shape[0] == pd.read_csv(annotations).shape[0] # Plot using geopandas - #fig, ax = plt.subplots(figsize=(10, 10)) - #gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) - #geo_coords.plot(ax=ax, color="red", alpha=0.2) - #show(src, ax=ax) - #plt.show() + # fig, ax = plt.subplots(figsize=(10, 10)) + # gpd.GeoSeries(src_window).plot(ax=ax, color="blue", alpha=0.5) + # geo_coords.plot(ax=ax, color="red", alpha=0.2) + # show(src, ax=ax) + # plt.show() def test_boxes_to_shapefile_projected(m): @@ -455,11 +495,11 @@ def test_boxes_to_shapefile_projected(m): r = rio.open(img) df = m.predict_image(path=img) gdf = utilities.boxes_to_shapefile(df, root_dir=os.path.dirname(img), projected=True) - - #Confirm that each boxes within image bounds + + # Confirm that each boxes within image bounds geom = geometry.box(*r.bounds) assert all(gdf.geometry.apply(lambda x: geom.intersects(geom)).values) - - #Edge case, only one row in predictions - gdf = utilities.boxes_to_shapefile(df.iloc[:1,], root_dir=os.path.dirname(img), projected=True) - assert gdf.shape[0] == 1 \ No newline at end of file + + # Edge case, only one row in predictions + gdf = utilities.boxes_to_shapefile(df.iloc[:1, ], root_dir=os.path.dirname(img), projected=True) + assert gdf.shape[0] == 1 diff --git a/tests/test_visualize.py b/tests/test_visualize.py index a72cdc764..9d14c83aa 100644 --- a/tests/test_visualize.py +++ b/tests/test_visualize.py @@ -1,4 +1,4 @@ -#Test visualize +# Test visualize from deepforest import visualize from deepforest import main from deepforest import get_data @@ -18,26 +18,27 @@ def test_format_boxes(m): paths, images, targets = batch for path, image, target in zip(paths, images, targets): target_df = visualize.format_boxes(target, scores=False) - assert list(target_df.columns.values) == ["xmin","ymin","xmax","ymax","label"] + assert list(target_df.columns.values) == ["xmin", "ymin", "xmax", "ymax", "label"] assert not target_df.empty - -#Test different color labels -@pytest.mark.parametrize("label",[0,1,20]) -def test_plot_predictions(m, tmpdir,label): + +# Test different color labels +@pytest.mark.parametrize("label", [0, 1, 20]) +def test_plot_predictions(m, tmpdir, label): ds = m.val_dataloader() batch = next(iter(ds)) paths, images, targets = batch for path, image, target in zip(paths, images, targets): target_df = visualize.format_boxes(target, scores=False) target_df["image_path"] = path - image = np.array(image)[:,:,::-1] - image = np.rollaxis(image,0,3) + image = np.array(image)[:, :, ::-1] + image = np.rollaxis(image, 0, 3) target_df.label = label image = visualize.plot_predictions(image, target_df) assert image.dtype == "uint8" - + + def test_plot_prediction_dataframe(m, tmpdir): ds = m.val_dataloader() batch = next(iter(ds)) @@ -45,21 +46,25 @@ def test_plot_prediction_dataframe(m, tmpdir): for path, image, target in zip(paths, images, targets): target_df = visualize.format_boxes(target, scores=False) target_df["image_path"] = path - filenames = visualize.plot_prediction_dataframe(df=target_df,savedir=tmpdir, root_dir=m.config["validation"]["root_dir"]) - + filenames = visualize.plot_prediction_dataframe( + df=target_df, savedir=tmpdir, root_dir=m.config["validation"]["root_dir"]) + assert all([os.path.exists(x) for x in filenames]) - + + def test_plot_predictions_and_targets(m, tmpdir): ds = m.val_dataloader() batch = next(iter(ds)) paths, images, targets = batch - m.model.eval() - predictions = m.model(images) + m.model.eval() + predictions = m.model(images) for path, image, target, prediction in zip(paths, images, targets, predictions): - image = image.permute(1,2,0) - save_figure_path = visualize.plot_prediction_and_targets(image, prediction, target, image_name=os.path.basename(path), savedir=tmpdir) + image = image.permute(1, 2, 0) + save_figure_path = visualize.plot_prediction_and_targets( + image, prediction, target, image_name=os.path.basename(path), savedir=tmpdir) assert os.path.exists(save_figure_path) + def test_convert_to_sv_format(): # Create a mock DataFrame data = { @@ -72,16 +77,15 @@ def test_convert_to_sv_format(): 'image_path': ['image1.jpg', 'image1.jpg'] } df = pd.DataFrame(data) - + # Call the function detections = visualize.convert_to_sv_format(df) - + # Expected values expected_boxes = np.array([[0, 0, 5, 5], [10, 20, 15, 25]], dtype=np.float32) expected_labels = np.array([0, 0]) expected_scores = np.array([0.9, 0.8]) - # Assertions np.testing.assert_array_equal(detections.xyxy, expected_boxes) np.testing.assert_array_equal(detections.class_id, expected_labels)