Skip to content

Commit

Permalink
Drop training argument from fit_class_random_forest
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Mar 29, 2022
1 parent 696f52b commit c42d3c9
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 14 deletions.
8 changes: 1 addition & 7 deletions openeo_driver/ProcessGraphDeserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,12 +734,6 @@ def fit_class_random_forest(args: dict, env: EvalEnv) -> DriverMlModel:

# TODO: get defaults from process spec?
# TODO: do parameter checks automatically based on process spec?
training = extract_arg(args, 'training')
if not isinstance(training, float) or training < 0.0 or training > 1.0:
raise ProcessParameterInvalidException(
parameter="training", process="fit_class_random_forest",
reason="should be a float between 0 and 1."
)
num_trees = args.get("num_trees", 100)
if not isinstance(num_trees, int) or num_trees < 0:
raise ProcessParameterInvalidException(
Expand All @@ -760,7 +754,7 @@ def fit_class_random_forest(args: dict, env: EvalEnv) -> DriverMlModel:
)

return predictors.fit_class_random_forest(
target=target, training=training,
target=target,
num_trees=num_trees, mtry=mtry, seed=seed,
)

Expand Down
2 changes: 1 addition & 1 deletion openeo_driver/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.17.5a1'
__version__ = '0.17.6a1'
5 changes: 2 additions & 3 deletions openeo_driver/dummy/dummy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,12 @@ def prepare_for_json(self):
return self.data

def fit_class_random_forest(
self, target: dict,
training: float, num_trees: int, mtry: Optional[int] = None, seed: Optional[int] = None
self, target: dict, num_trees: int, mtry: Optional[int] = None, seed: Optional[int] = None
) -> DriverMlModel:
# Fake ML training: just store inputs
return DummyMlModel(
process_id="fit_class_random_forest",
data=self.data, target=target, training=training, num_trees=num_trees, mtry=mtry, seed=seed,
data=self.data, target=target, num_trees=num_trees, mtry=mtry, seed=seed,
)


Expand Down
3 changes: 1 addition & 2 deletions openeo_driver/save_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,7 @@ def write_assets(self, directory: str) -> Dict[str, StacAsset]:
return {str(Path(filename).name): asset}

def fit_class_random_forest(
self, target: dict,
training: float, num_trees: int, mtry: Optional[int] = None, seed: Optional[int] = None
self, target: dict, num_trees: int, mtry: Optional[int] = None, seed: Optional[int] = None
) -> DriverMlModel:
# TODO: this method belongs eventually under DriverVectorCube
raise NotImplementedError
Expand Down
1 change: 0 additions & 1 deletion tests/test_views_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2378,7 +2378,6 @@ def test_fit_class_random_forest(api100):
"process_id": "fit_class_random_forest",
"data": [[100.0, 100.1, 100.2, 100.3], [101.0, 101.1, 101.2, 101.3]],
"target": DictSubSet({"type": "FeatureCollection"}),
"training": 0.5,
"num_trees": 200,
"mtry": None,
"seed": None,
Expand Down

0 comments on commit c42d3c9

Please sign in to comment.