Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/set context and feature usage #145

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/msgraph_core/graph_client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

import httpx
from kiota_http.kiota_client_factory import KiotaClientFactory
from kiota_http.middleware import AsyncKiotaTransport
from kiota_http.middleware.middleware import BaseMiddleware

from ._enums import APIVersion, NationalClouds
from .middleware import GraphTelemetryHandler
from .middleware import AsyncGraphTransport, GraphTelemetryHandler


class GraphClientFactory(KiotaClientFactory):
Expand Down Expand Up @@ -40,9 +39,10 @@ def create_with_default_middleware(
middleware, current_transport
)

client._transport = AsyncKiotaTransport(
client._transport = AsyncGraphTransport(
transport=current_transport, pipeline=middleware_pipeline
)
client._transport.pipeline
return client

@staticmethod
Expand All @@ -66,7 +66,7 @@ def create_with_custom_middleware(
middleware, current_transport
)

client._transport = AsyncKiotaTransport(
client._transport = AsyncGraphTransport(
transport=current_transport, pipeline=middleware_pipeline
)
return client
Expand Down
1 change: 1 addition & 0 deletions src/msgraph_core/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from .async_graph_transport import AsyncGraphTransport
from .request_context import GraphRequestContext
from .telemetry import GraphTelemetryHandler
44 changes: 44 additions & 0 deletions src/msgraph_core/middleware/async_graph_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json

import httpx
from kiota_http.middleware import MiddlewarePipeline, RedirectHandler, RetryHandler

from .._enums import FeatureUsageFlag
from .request_context import GraphRequestContext


class AsyncGraphTransport(httpx.AsyncBaseTransport):
"""A custom transport for requests to the Microsoft Graph API
"""

def __init__(self, transport: httpx.AsyncBaseTransport, pipeline: MiddlewarePipeline) -> None:
self.transport = transport
self.pipeline = pipeline

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if self.pipeline:
self.set_request_context_and_feature_usage(request)
response = await self.pipeline.send(request)
return response

response = await self.transport.handle_async_request(request)
return response

def set_request_context_and_feature_usage(self, request: httpx.Request) -> httpx.Request:

request_options = {}
options = request.headers.get('request_options', None)
if options:
request_options = json.loads(options)

context = GraphRequestContext(request_options, request.headers)
middleware = self.pipeline._first_middleware
while middleware:
if isinstance(middleware, RedirectHandler):
context.feature_usage = FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
if isinstance(middleware, RetryHandler):
context.feature_usage = FeatureUsageFlag.RETRY_HANDLER_ENABLED

middleware = middleware.next
request.context = context #type: ignore
return request
29 changes: 4 additions & 25 deletions src/msgraph_core/middleware/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import platform

import httpx
from kiota_http.middleware import AsyncKiotaTransport, BaseMiddleware, RedirectHandler, RetryHandler
from kiota_http.middleware import BaseMiddleware
from urllib3.util import parse_url

from .._constants import SDK_VERSION
from .._enums import FeatureUsageFlag, NationalClouds
from .._enums import NationalClouds
from .async_graph_transport import AsyncGraphTransport
from .request_context import GraphRequestContext


Expand All @@ -20,10 +21,9 @@ class GraphTelemetryHandler(BaseMiddleware):
the SDK team improve the developer experience.
"""

async def send(self, request: GraphRequest, transport: AsyncKiotaTransport):
async def send(self, request: GraphRequest, transport: AsyncGraphTransport):
"""Adds telemetry headers and sends the http request.
"""
self.set_request_context_and_feature_usage(request, transport)

if self.is_graph_url(request.url):
self._add_client_request_id_header(request)
Expand All @@ -34,27 +34,6 @@ async def send(self, request: GraphRequest, transport: AsyncKiotaTransport):
response = await super().send(request, transport)
return response

def set_request_context_and_feature_usage(
self, request: GraphRequest, transport: AsyncKiotaTransport
) -> GraphRequest:

request_options = {}
options = request.headers.pop('request_options', None)
if options:
request_options = json.loads(options)

request.context = GraphRequestContext(request_options, request.headers)
middleware = transport.pipeline._first_middleware
while middleware:
if isinstance(middleware, RedirectHandler):
request.context.feature_usage = FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
if isinstance(middleware, RetryHandler):
request.context.feature_usage = FeatureUsageFlag.RETRY_HANDLER_ENABLED

middleware = middleware.next

return request

def is_graph_url(self, url):
"""Check if the request is made to a graph endpoint. We do not add telemetry headers to
non-graph endpoints"""
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_async_graph_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from kiota_http.kiota_client_factory import KiotaClientFactory

from msgraph_core._enums import FeatureUsageFlag
from msgraph_core.middleware import AsyncGraphTransport, GraphRequestContext


def test_set_request_context_and_feature_usage(mock_request, mock_transport):
middleware = KiotaClientFactory.get_default_middleware()
pipeline = KiotaClientFactory.create_middleware_pipeline(middleware, mock_transport)
transport = AsyncGraphTransport(mock_transport, pipeline)
transport.set_request_context_and_feature_usage(mock_request)

assert hasattr(mock_request, 'context')
assert isinstance(mock_request.context, GraphRequestContext)
assert mock_request.context.feature_usage == hex(
FeatureUsageFlag.RETRY_HANDLER_ENABLED | FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
)
8 changes: 4 additions & 4 deletions tests/unit/test_graph_client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
# ------------------------------------
import httpx
import pytest
from kiota_http.middleware import AsyncKiotaTransport, MiddlewarePipeline, RedirectHandler
from kiota_http.middleware import MiddlewarePipeline, RedirectHandler

from msgraph_core import APIVersion, GraphClientFactory, NationalClouds
from msgraph_core.middleware.telemetry import GraphTelemetryHandler
from msgraph_core.middleware import AsyncGraphTransport, GraphTelemetryHandler


def test_create_with_default_middleware():
"""Test creation of GraphClient using default middleware"""
client = GraphClientFactory.create_with_default_middleware()

assert isinstance(client, httpx.AsyncClient)
assert isinstance(client._transport, AsyncKiotaTransport)
assert isinstance(client._transport, AsyncGraphTransport)
pipeline = client._transport.pipeline
assert isinstance(pipeline, MiddlewarePipeline)
assert isinstance(pipeline._first_middleware, RedirectHandler)
Expand All @@ -30,7 +30,7 @@ def test_create_with_custom_middleware():
client = GraphClientFactory.create_with_custom_middleware(middleware=middleware)

assert isinstance(client, httpx.AsyncClient)
assert isinstance(client._transport, AsyncKiotaTransport)
assert isinstance(client._transport, AsyncGraphTransport)
pipeline = client._transport.pipeline
assert isinstance(pipeline._first_middleware, GraphTelemetryHandler)

Expand Down
11 changes: 0 additions & 11 deletions tests/unit/test_graph_telemetry_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,11 @@
import pytest

from msgraph_core import SDK_VERSION, APIVersion, NationalClouds
from msgraph_core._enums import FeatureUsageFlag
from msgraph_core.middleware import GraphRequestContext, GraphTelemetryHandler

BASE_URL = NationalClouds.Global + '/' + APIVersion.v1


def test_set_request_context_and_feature_usage(mock_request, mock_transport):
telemetry_handler = GraphTelemetryHandler()
telemetry_handler.set_request_context_and_feature_usage(mock_request, mock_transport)

assert hasattr(mock_request, 'context')
assert mock_request.context.feature_usage == hex(
FeatureUsageFlag.RETRY_HANDLER_ENABLED | FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
)


def test_is_graph_url(mock_graph_request):
"""
Test method that checks whether a request url is a graph endpoint
Expand Down