Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle B3 trace_id and span_id correctly #934

Merged
merged 8 commits into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

import typing
from re import compile as re_compile

import opentelemetry.trace as trace
from opentelemetry.context import Context
from opentelemetry.sdk.trace import generate_span_id, generate_trace_id
from opentelemetry.trace.propagation.httptextformat import (
Getter,
HTTPTextFormat,
Expand All @@ -37,6 +39,8 @@ class B3Format(HTTPTextFormat):
SAMPLED_KEY = "x-b3-sampled"
FLAGS_KEY = "x-b3-flags"
_SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"])
_trace_id_regex = re_compile(r"[\da-fA-F]{16}|[\da-fA-F]{32}")
_span_id_regex = re_compile(r"[\da-fA-F]{16}")

def extract(
self,
Expand Down Expand Up @@ -95,19 +99,32 @@ def extract(
or flags
)

if (
self._trace_id_regex.fullmatch(trace_id) is None
or self._span_id_regex.fullmatch(span_id) is None
):
trace_id = generate_trace_id()
span_id = generate_span_id()
ocelotl marked this conversation as resolved.
Show resolved Hide resolved
sampled = "0"

else:
trace_id = int(trace_id, 16)
span_id = int(span_id, 16)

options = 0
# The b3 spec provides no defined behavior for both sample and
# flag values set. Since the setting of at least one implies
# the desire for some form of sampling, propagate if either
# header is set to allow.
if sampled in self._SAMPLE_PROPAGATE_VALUES or flags == "1":
options |= trace.TraceFlags.SAMPLED

return trace.set_span_in_context(
trace.DefaultSpan(
trace.SpanContext(
# trace an span ids are encoded in hex, so must be converted
trace_id=int(trace_id, 16),
span_id=int(span_id, 16),
trace_id=trace_id,
span_id=span_id,
is_remote=True,
trace_flags=trace.TraceFlags(options),
trace_state=trace.TraceState(),
Expand Down
45 changes: 45 additions & 0 deletions opentelemetry-sdk/tests/trace/propagation/test_b3_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import unittest
from unittest.mock import patch

import opentelemetry.sdk.trace as trace
import opentelemetry.sdk.trace.propagation.b3_format as b3_format
Expand Down Expand Up @@ -245,6 +246,50 @@ def test_missing_trace_id(self):
span_context = trace_api.get_current_span(ctx).get_context()
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)

@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id")
@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id")
def test_invalid_trace_id(
self, mock_generate_span_id, mock_generate_trace_id
):
"""If a trace id is invalid, generate a trace id."""

mock_generate_trace_id.configure_mock(return_value=1)
mock_generate_span_id.configure_mock(return_value=2)

carrier = {
FORMAT.TRACE_ID_KEY: "abc123",
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}

ctx = FORMAT.extract(get_as_list, carrier)
span_context = trace_api.get_current_span(ctx).get_context()

self.assertEqual(span_context.trace_id, 1)
self.assertEqual(span_context.span_id, 2)

@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id")
@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id")
def test_invalid_span_id(
self, mock_generate_span_id, mock_generate_trace_id
):
"""If a span id is invalid, generate a trace id."""

mock_generate_trace_id.configure_mock(return_value=1)
mock_generate_span_id.configure_mock(return_value=2)

carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.SPAN_ID_KEY: "abc123",
FORMAT.FLAGS_KEY: "1",
}

ctx = FORMAT.extract(get_as_list, carrier)
span_context = trace_api.get_current_span(ctx).get_context()

self.assertEqual(span_context.trace_id, 1)
self.assertEqual(span_context.span_id, 2)

def test_missing_span_id(self):
"""If a trace id is missing, populate an invalid trace id."""
carrier = {
Expand Down