Skip to content

Commit

Permalink
pydantic>2 (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
clausmichele committed May 27, 2024
1 parent cbcbe85 commit 8c1af9e
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 212 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_language_version:
python: python3
repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.10.1
Expand Down
5 changes: 4 additions & 1 deletion openeo_pg_parser_networkx/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import sys

sys.setrecursionlimit(16385) # Necessary when parsing really big graphs
import functools
import json
import logging
Expand Down Expand Up @@ -110,7 +113,7 @@ def _parse_datamodel(nested_graph: dict) -> ProcessGraph:
Parses a nested process graph into the Pydantic datamodel for ProcessGraph.
"""

return ProcessGraph.parse_obj(nested_graph)
return ProcessGraph.model_validate(nested_graph)

def _parse_process_graph(self, process_graph: ProcessGraph, arg_name: str = None):
"""
Expand Down
162 changes: 87 additions & 75 deletions openeo_pg_parser_networkx/pg_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from enum import Enum
from re import match
from typing import Any, Optional, Union
from typing import Annotated, Any, List, Optional, Union
from uuid import UUID, uuid4

import numpy as np
Expand All @@ -22,9 +22,13 @@
BaseModel,
Extra,
Field,
RootModel,
StringConstraints,
ValidationError,
conlist,
constr,
field_validator,
model_validator,
validator,
)
from shapely.geometry import Polygon
Expand Down Expand Up @@ -65,13 +69,14 @@ class ParameterReference(BaseModel, extra=Extra.forbid):


class ProcessNode(BaseModel, arbitrary_types_allowed=True):
process_id: constr(regex=r'^\w+$')
process_id: Annotated[str, StringConstraints(pattern=r'^\w+$')]

namespace: Optional[Optional[str]] = None
result: Optional[bool] = False
description: Optional[Optional[str]] = None
arguments: dict[
str,
Optional[
Annotated[
Union[
ResultReference,
ParameterReference,
Expand All @@ -87,11 +92,12 @@ class ProcessNode(BaseModel, arbitrary_types_allowed=True):
# GeoJson, disable while https://github.com/developmentseed/geojson-pydantic/issues/92 is open
Time,
float,
str,
bool,
list,
dict,
]
str,
],
Field(union_mode='left_to_right'),
],
]

Expand Down Expand Up @@ -133,9 +139,9 @@ class BoundingBox(BaseModel, arbitrary_types_allowed=True):
east: float
north: float
south: float
base: Optional[float]
height: Optional[float]
crs: Optional[Union[str, int]]
base: Optional[float] = None
height: Optional[float] = None
crs: Optional[Union[str, int]] = None

# validators
_parse_crs: classmethod = crs_validator('crs')
Expand All @@ -153,48 +159,54 @@ def polygon(self) -> Polygon:
)


class Date(BaseModel):
__root__: datetime.datetime
class Date(RootModel):
root: datetime.datetime

@validator("__root__", pre=True)
@field_validator("root", mode="before")
def validate_time(cls, value: Any) -> Any:
if (
isinstance(value, str)
and len(value) <= 11
and match(r"[0-9]{4}[-/][0-9]{2}[-/][0-9]{2}T?", value)
):
return pendulum.parse(value)
raise ValidationError("Could not parse `Date` from input.")
raise ValueError("Could not parse `Date` from input.")

def to_numpy(self):
return np.datetime64(self.__root__)
return np.datetime64(self.root)

def __repr__(self):
return self.__root__.__repr__()
return self.root.__repr__()

def __gt__(self, date1):
return self.root > date1.root


class DateTime(BaseModel):
__root__: datetime.datetime
class DateTime(RootModel):
root: datetime.datetime

@validator("__root__", pre=True)
@field_validator("root", mode="before")
def validate_time(cls, value: Any) -> Any:
if isinstance(value, str) and match(
r"[0-9]{4}-[0-9]{2}-[0-9]{2}T?[0-9]{2}:[0-9]{2}:?([0-9]{2})?Z?", value
):
return pendulum.parse(value)
raise ValidationError("Could not parse `DateTime` from input.")
raise ValueError("Could not parse `DateTime` from input.")

def to_numpy(self):
return np.datetime64(self.__root__)
return np.datetime64(self.root)

def __repr__(self):
return self.__root__.__repr__()
return self.root.__repr__()

def __gt__(self, date1):
return self.root > date1.root


class Time(BaseModel):
__root__: pendulum.Time
class Time(RootModel):
root: datetime.time

@validator("__root__", pre=True)
@field_validator("root", mode="before")
def validate_time(cls, value: Any) -> Any:
if (
isinstance(value, str)
Expand All @@ -203,145 +215,145 @@ def validate_time(cls, value: Any) -> Any:
and match(r"[0-9]{2}:[0-9]{2}:?([0-9]{2})?Z?", value)
):
return pendulum.parse(value).time()
raise ValidationError("Could not parse `Time` from input.")
raise ValueError("Could not parse `Time` from input.")

def to_numpy(self):
raise NotImplementedError

def __repr__(self):
return self.__root__.__repr__()
return self.time.__repr__()


class Year(BaseModel):
__root__: datetime.datetime
class Year(RootModel):
root: datetime.datetime

@validator("__root__", pre=True)
@field_validator("root", mode="before")
def validate_time(cls, value: Any) -> Any:
if isinstance(value, str) and len(value) <= 4 and match(r"^\d{4}$", value):
return pendulum.parse(value)
raise ValidationError("Could not parse `Year` from input.")
raise ValueError("Could not parse `Year` from input.")

def to_numpy(self):
return np.datetime64(self.__root__)
return np.datetime64(self.root)

def __repr__(self):
return self.__root__.__repr__()
return self.root.__repr__()


class Duration(BaseModel):
__root__: datetime.timedelta
class Duration(RootModel):
root: datetime.timedelta

@validator("__root__", pre=True)
@field_validator("root", mode="before")
def validate_time(cls, value: Any) -> Any:
if isinstance(value, str) and match(
r"P[0-9]*Y?[0-9]*M?[0-9]*D?T?[0-9]*H?[0-9]*M?[0-9]*S?", value
):
return pendulum.parse(value).as_timedelta()
raise ValidationError("Could not parse `Duration` from input.")
raise ValueError("Could not parse `Duration` from input.")

def to_numpy(self):
return np.timedelta64(self.__root__)
return np.timedelta64(self.root)

def __repr__(self):
return self.__root__.__repr__()
return self.root.__repr__()


class TemporalInterval(BaseModel):
__root__: conlist(Union[Year, Date, DateTime, Time, None], min_items=2, max_items=2)
class TemporalInterval(RootModel):
root: conlist(Union[Year, Date, DateTime, Time, None], min_length=2, max_length=2)

@validator("__root__")
@field_validator("root")
def validate_temporal_interval(cls, value: Any) -> Any:
start = value[0]
end = value[1]

if start is None and end is None:
raise ValidationError("Could not parse `TemporalInterval` from input.")
raise ValueError("Could not parse `TemporalInterval` from input.")

# Disambiguate the Time subtype
if isinstance(start, Time) or isinstance(end, Time):
if isinstance(start, Time) and isinstance(end, Time):
raise ValidationError(
raise ValueError(
"Ambiguous TemporalInterval, both start and end are of type `Time`"
)
if isinstance(start, Time):
if end is None:
raise ValidationError(
raise ValueError(
"Cannot disambiguate TemporalInterval, start is `Time` and end is `None`"
)
logger.warning(
"Start time of temporal interval is of type `time`. Assuming same date as the end time."
)
start = DateTime(
__root__=pendulum.datetime(
end.__root__.year,
end.__root__.month,
end.__root__.day,
start.__root__.hour,
start.__root__.minute,
start.__root__.second,
start.__root__.microsecond,
root=pendulum.datetime(
end.root.year,
end.root.month,
end.root.day,
start.root.hour,
start.root.minute,
start.root.second,
start.root.microsecond,
).to_rfc3339_string()
)
elif isinstance(end, Time):
if start is None:
raise ValidationError(
raise ValueError(
"Cannot disambiguate TemporalInterval, start is `None` and end is `Time`"
)
logger.warning(
"End time of temporal interval is of type `time`. Assuming same date as the start time."
)
end = DateTime(
__root__=pendulum.datetime(
start.__root__.year,
start.__root__.month,
start.__root__.day,
end.__root__.hour,
end.__root__.minute,
end.__root__.second,
end.__root__.microsecond,
root=pendulum.datetime(
start.root.year,
start.root.month,
start.root.day,
end.root.hour,
end.root.minute,
end.root.second,
end.root.microsecond,
).to_rfc3339_string()
)

if not (start is None or end is None) and start.__root__ > end.__root__:
raise ValidationError("Start time > end time")
if not (start is None or end is None) and start > end:
raise ValueError("Start time > end time")

return [start, end]

@property
def start(self):
return self.__root__[0]
return self.root[0]

@property
def end(self):
return self.__root__[1]
return self.root[1]

def __iter__(self):
return iter(self.__root__)
return iter(self.root)

def __getitem__(self, item):
return self.__root__[item]
return self.root[item]


class TemporalIntervals(BaseModel):
__root__: list[TemporalInterval]
class TemporalIntervals(RootModel):
root: list[TemporalInterval]

def __iter__(self):
return iter(self.__root__)
return iter(self.root)

def __getitem__(self, item) -> TemporalInterval:
return self.__root__[item]
return self.root[item]


GeoJson = Union[FeatureCollection, Feature, GeometryCollection, MultiPolygon, Polygon]
# The GeoJson spec (https://www.rfc-editor.org/rfc/rfc7946.html#ref-GJ2008) doesn't
# have a crs field anymore and recommends assuming it to be EPSG:4326, so we do the same.


class JobId(BaseModel):
__root__: str = Field(
regex=r"(eodc-jb-|jb-)[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}"
class JobId(RootModel):
root: str = Field(
pattern=r"(eodc-jb-|jb-)[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}"
)


ResultReference.update_forward_refs()
ProcessNode.update_forward_refs()
ResultReference.model_rebuild()
ProcessNode.model_rebuild()
4 changes: 2 additions & 2 deletions openeo_pg_parser_networkx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
def parse_nested_parameter(parameter: Any):
try:
return ResultReference.parse_obj(parameter)
except pydantic.error_wrappers.ValidationError:
except pydantic.ValidationError:
pass
except TypeError:
pass

try:
return ParameterReference.parse_obj(parameter)
except pydantic.error_wrappers.ValidationError:
except pydantic.ValidationError:
pass
except TypeError:
pass
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ packages = [

[tool.poetry.dependencies]
python = ">=3.9,<3.12"
pydantic = "^1.9.1"
pydantic = "^2.4.0"
pyproj = "^3.4.0"
networkx = "^2.8.6"
shapely = ">=1.8"
geojson-pydantic = "^0.5.0"
geojson-pydantic = "^1.0.0"
numpy = "^1.20.3"
pendulum = "^2.1.2"
matplotlib = { version = "^3.7.1", optional = true }
Expand Down
Loading

0 comments on commit 8c1af9e

Please sign in to comment.