Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more Prometheus metrics #2764

Merged
merged 50 commits into from
Apr 28, 2024
Merged

Add more Prometheus metrics #2764

merged 50 commits into from
Apr 28, 2024

Conversation

ronensc
Copy link
Contributor

@ronensc ronensc commented Feb 5, 2024

This PR adds the Prometheus metrics defined in #2650

@ronensc ronensc changed the title Title: Add more Prometheus metrics Add more Prometheus metrics Feb 5, 2024
vllm/engine/metrics.py Outdated Show resolved Hide resolved
vllm:request_max_tokens -> vllm:request_params_max_tokens
vllm:request_n -> vllm:request_params_n
@ronensc
Copy link
Contributor Author

ronensc commented Feb 12, 2024

@rib-2, I highly value your opinion. Would you please review my pull request?

@simon-mo
Copy link
Collaborator

@ronensc
Copy link
Contributor Author

ronensc commented Mar 18, 2024

@simon-mo Could you please review this PR?

Comment on lines 67 to 78
self.histogram_request_prompt_tokens = Histogram(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_request_generation_tokens = Histogram(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two could be constructed using vllm:prompt_tokens_total and vllm:generation_tokens_total using a Binary operation transform in Grafana.

It wouldn't be exactly the same, but it would prevent additional overhead in the server. i.e. if you calculate it on grafana (and your scrape interval is 1 minute) then it'd be a histogram of how many tokens get processed/generated per minute rather than how many tokens get processed/generated per request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your feedback!
You are right. But, wouldn't it be beneficial to have in addition histograms depicting the distribution of prompt length and generation length?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this metric doesn't actually introduce any overhead (because the data from vllm:x_tokens_total is reused, these two are probably fine. It would be interesting to know how big the prompts the users were providing are.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly! I suggest to deprecate the 2 vllm:x_tokens_total metrics as they will be included as part of the Histogram metrics this PR adds.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should keep these metrics, because a developer may not want to have to aggregate histogram data in order to get the same effect of vllm:x_tokens_total

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prometheus histograms have this nice feature where in addition to the bucket counters, they include 2 additional counters
suffixed with _sum and _count.

_count is incremented by 1 on every observe, and _sum is incremented by the value of the observation.

Therfore, vllm:prompt_tokens_total is equivalent to vllm:request_prompt_tokens_sum,
and vllm:generation_tokens_total is equivalent to vllm:request_generation_tokens_sum

Source:
https://www.robustperception.io/how-does-a-prometheus-histogram-work/

Copy link
Collaborator

@hmellor hmellor Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, thanks for explaining. In that case you could move the vllm:x_tokens_total metrics into the # Legacy metrics section.

Although I think there might be some objection to changing metrics that people are already using in dashboards.

cc @simon-mo @Yard1 @robertgshaw2-neuralmagic (not sure who to ping for metrics related things, so please tell me if I should stop)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from my point of view, it's fine to duplicate metrics for backward compatibility reason.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll relocate these metrics to the legacy section. Perhaps in the future, when we're able to make breaking changes, we can consider removing them.

Comment on lines 63 to 66
self.counter_request_success = Counter(
name="vllm:request_success",
documentation="Count of successfully processed requests.",
labelnames=labelnames)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't just counting successful responses, it's counting all finish reasons. If you could find an elegant way to implement the counters we lost when switching from aioprometheus to prometheus_client, that would be great!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A quick option to add http related metrics would be to use prometheus-fastapi-instrumentator.

This involves installing the package:
pip install prometheus-fastapi-instrumentator

Then adding the following 2 lines after the app creation:

app = fastapi.FastAPI(lifespan=lifespan)

from prometheus_fastapi_instrumentator import Instrumentator
Instrumentator().instrument(app).expose(app)

This will add the following metrics:

Metric Name Type Description
http_requests_total counter Total number of requests by method, status, and handler.
http_request_size_bytes summary Content length of incoming requests by handler. Only value of header is respected. Otherwise ignored.
http_response_size_bytes summary Content length of outgoing responses by handler. Only value of header is respected. Otherwise ignored.
http_request_duration_highr_seconds histogram Latency with many buckets but no API specific labels. Made for more accurate percentile calculations.
http_request_duration_seconds histogram Latency with only a few buckets by handler. Made to be only used if aggregation by handler is important.

Should I add it to the PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this solution, it saves us from reinventing the wheel in vLLM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done :)

Copy link
Collaborator

@hmellor hmellor Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, if we're going to be using prometheus-fastapi-instrumentator then this implementation should be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm not sure I'm following. Which implementation are you referring to that should be removed?
IIUC, prometheus-fastapi-instrumentator simply adds a middleware into FastAPI to collect the metrics specified in the table above. It uses prometheus_client under the hood and adding other vLLM related metrics should be done with prometheus_client.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code highlighted in the original comment

        self.counter_request_success = Counter(
            name="vllm:request_success",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames)

can be removed if we are getting these metrics from prometheus-fastapi-instrumentator instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for clarifying. I believe that vllm:request_success remains valuable. It includes a finished_reason label, which allows for counting requests based on their finished reason — either stop if the sequence ends with an EOS token, or length if the sequence length reaches either scheduler_config.max_model_len or sampling_params.max_tokens. I'm open to adjusting its name and description to make it more indicative.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think of the idea of renaming this to something like vllm:request_info and including n and best_of as labels too? This way we log a single metric from which the user can construct many different visualisations on Grafana by utilising the label filters?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While the idea of combining these metrics may seem appealing at first glance, I believe they should be kept separate for the following reasons:

  1. The metrics have different types: vllm:request_success is a Counter, while vllm:request_params_best_of and vllm:request_params_n are Histograms.

  2. Aggregating different labels lacks semantic meaning.

  3. Although merging n and best_of into the same histograms might make sense in this case, as they would share the same buckets, we may encounter scenarios where we need to introduce another metric with different bucket requirements.

  4. This situation differs from the Info metric type, where data is encoded in the label values.

Comment on lines 100 to 111
self.histogram_max_tokens = Histogram(
name="vllm:request_params_max_tokens",
documentation="Histogram of the max_tokens request parameter.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_request_n = Histogram(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe these should go in an Info, like the cache_config?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, Info is intended to collect constant metrics such as configuration values. vllm:request_params_max_tokens and vllm:request_params_n are intended to collect the values of the max_tokens and n arguments in each request, which may vary with each request.
For instance, in the following request, max_tokens=7 and n=3 will be collected.

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "facebook/opt-125m",
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "best_of": 5,
        "n": 3,
        "use_beam_search": "true",
        "temperature": 0
    }'

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

Although, I'm not convinced that it's a useful datapoint to collect (I am open to being convinced though!). i.e. what do we, as the vLLM provider, learn from the token limit the user sets?

Also, in the vast majority of use cases a user will statically set their sampling parameters in their application and then never touch them again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point. Having a histogram of the token limit set by the user might indeed be redundant, particularly considering we have a histogram of the actual number of generated tokens per request (vllm:request_generation_tokens_sum). I'll remove it.

Regarding n and best_of, I realized I overlooked adding a metric for collecting best_of. Just to clarify, best_of determines the width of the beam search, while n specifies how many "top beams" to return (n <= best_of). AFAIU, larger values of best_of can significantly impact the engine, as the batch will be dominated by sequences from a few requests.

It might be insightful to compare the histograms of both n and best_of. A significant deviation between them could suggest that the engine is processing a substantial number of tokens that users aren't actually consuming.

Please let me know your thoughts on this.

Copy link
Collaborator

@hmellor hmellor Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Knowing the amount of tokens "wasted" in beam search could be interesting to know as a developer of vLLM. i.e. if beam search is being used extensively and many tokens are being "wasted", it signals that we need to optimise beam search if we can.

@simon-mo @Yard1 @robertgshaw2-neuralmagic what do you think about this? (not sure who to ping for metrics related things, so please tell me if I should stop)

Comment on lines 55 to 61
# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how the metrics defined in vllm/engine/metrics.py are exposed. It can't be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I replaced it with the expose() method of prometheus-fastapi-instrumentator which also exposes a /metric endpoint.
https://github.com/vllm-project/vllm/pull/2764/files#diff-38318677b76349044192bf70161371c88fb2818b85279d8fc7f2c041d83a9544R48-R49

I noticed it also solves the /metrics/ redirection issue.
Which of the exposing methods should we use?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I replaced it with the expose() method of prometheus-fastapi-instrumentator which also exposes a /metric endpoint.

While this does expose a /metrics endpoint, none of the vLLM metrics will be in it because they come from make_asgi_app(), right?

Have you confirmed that /metrics still contains vLLM metrics with this code removed?

Copy link
Contributor Author

@ronensc ronensc Mar 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I've verified that both approaches expose all metrics. The only discrepancy I've noticed is that expose() from prometheus-fastapi-instrumentator exposes metrics on /metrics, whereas make_asgi_app() exposes them on /metrtics/. However, I'll revert to using the make_asgi_app() approach. I find the other method somewhat hacky, as it involves the prometheus-fastapi-instrumentator middleware handling the metrics endpoint. This could look weird if multiple middlewares are in use.

@hmellor hmellor mentioned this pull request Mar 28, 2024
ronensc and others added 3 commits April 20, 2024 13:30
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
@ronensc
Copy link
Contributor Author

ronensc commented Apr 20, 2024

  1. I've incorporated most of the changes from Fixes + More Metrics ronensc/vllm#1 into this PR.
  2. I'm not sure the whether assumption in maybe_get_last_latency() and maybe_set_first_token_time() is correct.
    Both methods are called after self.output_processor.process_outputs() (and specifically after seq_group.update_num_computed_tokens()). By this point, when the first generation token is ready, it is already added to the sequence, so the state of seq_group is changed from PREFILL to DECODE and get_num_uncomputed_tokens() == 1.
    For maybe_set_first_token_time(), I suggest using the condition self.get_seqs()[0].get_output_len() == 1 to determine when the first token is generated.
    As for maybe_get_last_latency(), I suggest using the condition self.is_prefill() to check when chunked_prefill is ongoing.
  3. I modified num_generation_tokens_iter += 1 to num_generation_tokens_iter += seq_group.num_unfinished_seqs() to accommodate requests with more than one sequence (like beam search and parallel sampling).
  4. I'll postpone adding the additional metrics until we get the current set of metrics right.

Note: I applied the changes from #4150 locally to aid in debugging.

@robertgshaw2-neuralmagic
Copy link
Sponsor Collaborator

@ronensc is this ready for review?

@ronensc
Copy link
Contributor Author

ronensc commented Apr 22, 2024

In current state of the PR, some of the metrics are still inaccurate in chunked_prefill.
Before addressing the chunked_prefill issue, could we please merge this PR up to commit 5ded719 (before attempting to solve the chunked_prefill issue)? We can tackle the chunked_prefill problem in a follow-up PR. What do you think?
Also, just a heads-up, I'll be less available in the coming days.

Copy link
Collaborator

@hmellor hmellor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my perspective this LGTM.

Once @robertgshaw2-neuralmagic also approves, I'd say this is good to merge.

@rkooo567 rkooo567 self-assigned this Apr 22, 2024
@njhill
Copy link
Member

njhill commented Apr 25, 2024

Huge thanks for all the work on this and reviews @ronensc @robertgshaw2-neuralmagic @hmellor

@robertgshaw2-neuralmagic
Copy link
Sponsor Collaborator

Im just thinking though

"I'm not sure the whether assumption in maybe_get_last_latency() and maybe_set_first_token_time() is correct.
Both methods are called after self.output_processor.process_outputs() (and specifically after seq_group.update_num_computed_tokens()). By this point, when the first generation token is ready, it is already added to the sequence, so the state of seq_group is changed from PREFILL to DECODE and get_num_uncomputed_tokens() == 1.
For maybe_set_first_token_time(), I suggest using the condition self.get_seqs()[0].get_output_len() == 1 to determine when the first token is generated.
As for maybe_get_last_latency(), I suggest using the condition self.is_prefill() to check when chunked_prefill is ongoing."

Will merge this weekend

@robertgshaw2-neuralmagic
Copy link
Sponsor Collaborator

robertgshaw2-neuralmagic commented Apr 28, 2024

@simon-mo @njhill had to make a couple changes for correctness due to some subtlety with chunked_prefill

Mind giving brief stamp?

@simon-mo simon-mo disabled auto-merge April 28, 2024 22:59
@simon-mo simon-mo merged commit bf480c5 into vllm-project:main Apr 28, 2024
46 of 48 checks passed
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants