diff --git a/deepforest/dataset.py b/deepforest/dataset.py
index cdf648df..72bcd70a 100644
--- a/deepforest/dataset.py
+++ b/deepforest/dataset.py
@@ -139,7 +139,6 @@ def __getitem__(self, idx):
 
 
 class TileDataset(Dataset):
-
     def __init__(self,
                  tile: typing.Optional[np.ndarray],
                  preload_images: bool = False,
diff --git a/deepforest/main.py b/deepforest/main.py
index bca9386e..64b4b341 100644
--- a/deepforest/main.py
+++ b/deepforest/main.py
@@ -126,7 +126,7 @@ def use_bird_release(self, check_release=True):
         print("Setting default score threshold to 0.3")
         self.config["score_thresh"] = 0.3
 
-    def create_model(self, backbone=None):
+    def create_model(self):
         """Define a deepforest architecture. This can be done in two ways.
         Passed as the model argument to deepforest __init__(), or as a named architecture in config["architecture"],
@@ -138,12 +138,10 @@ def create_model(self, backbone=None):
         RCNN:
             nms_thresh: 0.1
             etc.
-        Args:
-            backbone:
         """
         if self.model is None:
             model_name = importlib.import_module("deepforest.models.{}".format(
                 self.config["architecture"]))
-            self.model = model_name.Model(config=self.config, backbone=backbone).create_model()
+            self.model = model_name.Model(config=self.config).create_model()
         else:
             pass
diff --git a/deepforest/model.py b/deepforest/model.py
index 212d2075..40f03503 100644
--- a/deepforest/model.py
+++ b/deepforest/model.py
@@ -21,7 +21,7 @@ def __init__(self, config):
         # Check input output format:
         self.check_model()
 
-    def create_model(backbone=None):
+    def create_model():
         """This function converts a deepforest config file into a model. An architecture should have a list of nested arguments in config that match this function"""
         raise ValueError("The create_model class method needs to be implemented. Take in args and return a pytorch nn module.")
diff --git a/deepforest/models/retinanet.py b/deepforest/models/retinanet.py
index 25060b71..aa14a34d 100644
--- a/deepforest/models/retinanet.py
+++ b/deepforest/models/retinanet.py
@@ -37,7 +37,7 @@ def create_anchor_generator(self, sizes=((8, 16, 32, 64, 128, 256, 400),),
 
         return anchor_generator
 
-    def create_model(self, backbone=None):
+    def create_model(self):
         """Create a retinanet model
         Args:
             num_classes (int): number of classes in the model
@@ -46,9 +46,8 @@ def create_model(self, backbone=None):
         Returns:
             model: a pytorch nn module
         """
-        if backbone is None:
-            resnet = self.load_backbone()
-            backbone = resnet.backbone
+        resnet = self.load_backbone()
+        backbone = resnet.backbone
 
         model = RetinaNet(backbone=backbone, num_classes=self.config["num_classes"])
         model.nms_thresh = self.config["nms_thresh"]
diff --git a/docs/Model_Architecture.md b/docs/Model_Architecture.md
new file mode 100644
index 00000000..ac4dd425
--- /dev/null
+++ b/docs/Model_Architecture.md
@@ -0,0 +1,79 @@
+# Model Architecture
+
+DeepForest allows users to specify custom model architectures if they follow certain guidelines.
+To create a compliant model, follow the recipe below.
+
+## Subclass the model.Model() structure
+
+A subclass is a class that inherits the methods and attributes of its super class. In this case, model.Model() is defined as:
+
+```
+# Model - common class
+from deepforest.models import *
+import torch
+
+class Model():
+    """A architecture agnostic class that controls the basic train, eval and predict functions.
+    A model should optionally allow a backbone for pretraining. To add new architectures, simply create a new module in models/ and write a create_model.
+    Then add the result to the if else statement below.
+
+    Args:
+        num_classes (int): number of classes in the model
+        nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1]
+        score_thresh (float): minimum prediction score to keep during prediction [0,1]
+    Returns:
+        model: a pytorch nn module
+    """
+    def __init__(self, config):
+
+        # Check for required properties and formats
+        self.config = config
+
+        # Check input output format:
+        self.check_model()
+
+    def create_model():
+        """This function converts a deepforest config file into a model. An architecture should have a list of nested arguments in config that match this function"""
+        raise ValueError("The create_model class method needs to be implemented. Take in args and return a pytorch nn module.")
+
+    def check_model(self):
+        """
+        Ensure that model follows deepforest guidelines
+        If fails, raise ValueError
+        """
+        # This assumes model creation is not expensive
+        test_model = self.create_model()
+        test_model.eval()
+
+        # Create a dummy batch of 3 band data.
+        x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+
+        predictions = test_model(x)
+        # Model takes in a batch of images
+        assert len(predictions) == 2
+
+        # Returns a list equal to number of images with proper keys per image
+        model_keys = list(predictions[1].keys())
+        model_keys.sort()
+        assert model_keys == ['boxes','labels','scores']
+```
+
+## Match torchvision formats
+
+From this definition we can see the format requirements. The model must take in a batch of images in [channels, height, width] order; the current model weights are trained on 3-band images, but you can update the check_model function if you have other image dimensions.
+The model must also output a list of dictionaries, one per image, with the keys ["boxes", "labels", "scores"], where the boxes follow the torchvision object detection format. From the torchvision [docs](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.retinanet_resnet50_fpn.html#torchvision.models.detection.retinanet_resnet50_fpn):
+
+```
+During training, the model expects both the input tensors and targets (list of dictionary), containing:
+
+boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
+
+labels (Int64Tensor[N]): the class label for each ground-truth box
+
+During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows, where N is the number of detections:
+
+boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
+
+labels (Int64Tensor[N]): the predicted labels for each detection
+
+scores (Tensor[N]): the scores of each detection
+```
diff --git a/tests/test_retinanet.py b/tests/test_retinanet.py
index f389b2bc..7b10df68 100644
--- a/tests/test_retinanet.py
+++ b/tests/test_retinanet.py
@@ -6,6 +6,9 @@
 import torch
 import torchvision
 import os
+from torchvision.models import resnet50, ResNet50_Weights
+from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights
+
 os.environ['KMP_DUPLICATE_LIB_OK']='True'
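
For reference, a custom architecture following the recipe in docs/Model_Architecture.md might look like the sketch below. This is not part of the diff: the Faster R-CNN base model, the module name `mymodel`, and the exact way the `num_classes`, `nms_thresh`, and `score_thresh` config keys are applied are illustrative assumptions (torchvision >= 0.13 is assumed for the `weights=` keyword). The file would live under `deepforest/models/` so that setting `config["architecture"] = "mymodel"` resolves through the `importlib.import_module` call in `main.create_model()` above.

```python
# Hypothetical deepforest/models/mymodel.py -- an illustrative sketch only,
# not part of this changeset. It follows the docs/Model_Architecture.md recipe:
# subclass Model and implement create_model() with no backbone argument.
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from deepforest.model import Model as BaseModel


class Model(BaseModel):

    def create_model(self):
        """Convert the deepforest config into a torchvision detection module."""
        # Any torchvision detector works as long as it returns per-image dicts
        # with "boxes", "labels" and "scores" at inference time, which is what
        # check_model() in deepforest/model.py verifies.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)

        # Resize the classification head to the configured number of classes.
        # Faster R-CNN reserves index 0 for background, hence the +1; this is
        # a modeling choice, not a deepforest requirement.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, self.config["num_classes"] + 1)

        # Mirror the thresholds used by the retinanet module in this diff.
        model.roi_heads.nms_thresh = self.config["nms_thresh"]
        model.roi_heads.score_thresh = self.config["score_thresh"]

        return model
```

With such a file in place, a config containing `architecture: mymodel` plus `num_classes`, `nms_thresh`, and `score_thresh` entries is all `main.deepforest.create_model()` needs, since after this change it instantiates `model_name.Model(config=self.config).create_model()` with no backbone argument.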