Skip to content

Commit

Permalink
inv and black
Browse files Browse the repository at this point in the history
  • Loading branch information
msyyc committed May 22, 2024
1 parent 49ab757 commit 3c22f8e
Show file tree
Hide file tree
Showing 74 changed files with 3,199 additions and 1,206 deletions.
8 changes: 8 additions & 0 deletions .chronus/changes/deserialization-fix-2024-4-22-17-3-4.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
changeKind: fix
packages:
- "@autorest/python"
- "@azure-tools/typespec-python"
---

Fix deserialization error for lro when return type has discriminator and succeed in initial response
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,11 @@ def call_request_builder(self, builder: OperationType, is_paging: bool = False)
def deserialize_for_stream_res(self) -> str:
if self.code_model.options["version_tolerant"]:
return "response.iter_bytes()"
return "(await response.load_body()) or response._content # pylint: disable=protected-access" if self.async_mode else f"response.stream_download(self._client.{self.pipeline_name})"
return (
"(await response.load_body()) or response._content # pylint: disable=protected-access"
if self.async_mode
else f"response.stream_download(self._client.{self.pipeline_name})"
)

def response_headers_and_deserialization(
self,
Expand Down Expand Up @@ -981,7 +985,9 @@ def response_headers_and_deserialization(
def handle_error_response(self, builder: OperationType) -> List[str]:
async_await = "await " if self.async_mode else ""
retval = [f"if response.status_code not in {str(builder.success_status_codes)}:"]
need_download = builder.has_stream_kwargs and self.async_mode and not self.code_model.options["version_tolerant"]
need_download = (
builder.has_stream_kwargs and self.async_mode and not self.code_model.options["version_tolerant"]
)
if not self.code_model.need_request_converter or need_download:
load_func = "load_body" if need_download else "read"
retval.extend(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -124,10 +124,13 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
if _stream:
deserialized = response.iter_bytes()
else:
deserialized = None
if response.content:
deserialized = response.json()
else:
deserialized = None

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -241,6 +244,7 @@ def begin_basic_polling(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def _basic_polling_initial(
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -97,10 +97,13 @@ async def _basic_polling_initial(

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
if _stream:
deserialized = response.iter_bytes()
else:
deserialized = None
if response.content:
deserialized = response.json()
else:
deserialized = None

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -214,6 +217,7 @@ async def begin_basic_polling(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,26 @@ async def _test_lro_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200, 204]:
if _stream:
await response.load_body() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize.failsafe_deserialize(_models.Error, pipeline_response)
raise HttpResponseError(response=response, model=error, error_format=ARMErrorFormat)

deserialized = None
if response.status_code == 200:
deserialized = self._deserialize("Product", pipeline_response)
if _stream:
deserialized = (await response.load_body()) or response._content # pylint: disable=protected-access
else:
deserialized = self._deserialize("Product", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -230,14 +235,18 @@ async def begin_test_lro(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
)
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
deserialized = self._deserialize("Product", pipeline_response)
_response = (
pipeline_response if getattr(pipeline_response, "context", {}) else pipeline_response.http_response
)
deserialized = self._deserialize("Product", _response)
if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return deserialized
Expand Down Expand Up @@ -294,18 +303,23 @@ async def _test_lro_and_paging_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
await response.load_body() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

deserialized = self._deserialize("PagingResult", pipeline_response)
if _stream:
deserialized = (await response.load_body()) or response._content # pylint: disable=protected-access
else:
deserialized = self._deserialize("PagingResult", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -411,6 +425,7 @@ async def get_next(next_link=None):
client_request_id=client_request_id,
test_lro_and_paging_options=test_lro_and_paging_options,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _test_lro_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -239,7 +239,10 @@ def _test_lro_initial(

deserialized = None
if response.status_code == 200:
deserialized = self._deserialize("Product", pipeline_response)
if _stream:
deserialized = response.stream_download(self._client._pipeline)
else:
deserialized = self._deserialize("Product", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -304,14 +307,18 @@ def begin_test_lro(
product=product,
content_type=content_type,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
)
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
deserialized = self._deserialize("Product", pipeline_response)
_response = (
pipeline_response if getattr(pipeline_response, "context", {}) else pipeline_response.http_response
)
deserialized = self._deserialize("Product", _response)
if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return deserialized
Expand Down Expand Up @@ -368,7 +375,7 @@ def _test_lro_and_paging_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -379,7 +386,10 @@ def _test_lro_and_paging_initial(
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

deserialized = self._deserialize("PagingResult", pipeline_response)
if _stream:
deserialized = response.stream_download(self._client._pipeline)
else:
deserialized = self._deserialize("PagingResult", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -485,6 +495,7 @@ def get_next(next_link=None):
client_request_id=client_request_id,
test_lro_and_paging_options=test_lro_and_paging_options,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1608,18 +1608,23 @@ async def _get_multiple_pages_lro_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [202]:
if _stream:
await response.load_body() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

deserialized = self._deserialize("ProductResult", pipeline_response)
if _stream:
deserialized = (await response.load_body()) or response._content # pylint: disable=protected-access
else:
deserialized = self._deserialize("ProductResult", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -1726,6 +1731,7 @@ async def get_next(next_link=None):
client_request_id=client_request_id,
paging_get_multiple_pages_lro_options=paging_get_multiple_pages_lro_options,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2067,7 +2067,7 @@ def _get_multiple_pages_lro_initial(
_request = _convert_request(_request)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)
Expand All @@ -2078,7 +2078,10 @@ def _get_multiple_pages_lro_initial(
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, error_format=ARMErrorFormat)

deserialized = self._deserialize("ProductResult", pipeline_response)
if _stream:
deserialized = response.stream_download(self._client._pipeline)
else:
deserialized = self._deserialize("ProductResult", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
Expand Down Expand Up @@ -2185,6 +2188,7 @@ def get_next(next_link=None):
client_request_id=client_request_id,
paging_get_multiple_pages_lro_options=paging_get_multiple_pages_lro_options,
cls=lambda x, y, z: x,
stream=True,
headers=_headers,
params=_params,
**kwargs,
Expand Down
Loading

0 comments on commit 3c22f8e

Please sign in to comment.