custom-vllm / app.py
claudiubarbu's picture
added streaming endpoint
76a36b2
raw
history blame
No virus
5.73 kB
import json
import uuid

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
app = FastAPI()

# vLLM inference engine, created once when the module is imported and shared
# by every request handled below. The values trade throughput for a smaller
# GPU memory footprint:
#   - dtype="half": load parameters in float16, half the memory of float32.
#   - enforce_eager=True: run in eager PyTorch mode instead of using CUDA
#     graphs; graphs are usually faster, so revisit this if latency matters.
#   - gpu_memory_utilization=0.8: fraction of GPU memory vLLM may claim.
#   - max_model_len=512: longest sequence (prompt + generated tokens) allowed.
#   - max_num_batched_tokens=1024: most tokens processed in one batch; kept
#     at least max_model_len so a maximal-length prompt fits in one batch.
#   - max_num_seqs=8: most sequences processed concurrently; smaller values
#     reduce memory pressure on the GPU.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(
        model="claudiubarbu/HW2-orpo",
        dtype="half",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        max_model_len=512,
        max_num_batched_tokens=1024,
        max_num_seqs=8,
    )
)
class GenerationRequest(BaseModel):
    """JSON body accepted by ``POST /generate-stream``.

    FastAPI uses this model to parse the incoming JSON, validate it, fill in
    defaults, and describe the schema in the generated API docs. The handler
    therefore receives typed values that have already been checked.

    The numeric constraints mirror what vLLM's ``SamplingParams`` accepts.
    Invalid values are rejected here with a 422 response before streaming
    starts. Otherwise ``SamplingParams`` raises only after the streaming
    response has already begun.
    """

    # Text the model continues from.
    prompt: str
    # Maximum number of new tokens to generate; SamplingParams requires >= 1.
    max_tokens: int = Field(default=100, ge=1)
    # Sampling temperature; 0 means greedy decoding, higher values add
    # randomness. SamplingParams rejects negative values.
    temperature: float = Field(default=0.7, ge=0.0)
async def generate_stream(prompt: str, max_tokens: int, temperature: float):
    """Yield Server-Sent Events carrying the model's output for ``prompt``.

    This is an async generator. Nothing in its body runs until the caller
    (the ``StreamingResponse`` in the route) starts iterating it. Execution
    pauses at each ``yield``, so other requests can be served in between.

    Args:
        prompt: Text the model continues from.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling randomness. Higher values (e.g. 1.0) give more
            random output; lower values (e.g. 0.1) give more deterministic
            output.

    Yields:
        SSE frames of the form ``data: <payload>`` followed by a blank line:

        - ``{"text": ...}`` for every update from the engine. NOTE: with
          vLLM's default output kind this is the whole text generated so
          far, not only the newest tokens. Clients should replace their
          buffer rather than append to it.
        - ``[DONE]`` after generation completes successfully.
        - ``{"error": ...}`` instead of ``[DONE]`` if building the sampling
          parameters or running the engine raises; the stream then ends.
    """
    # vLLM tracks in-flight work by request_id. A fresh UUID per call keeps
    # concurrent requests from colliding.
    request_id = str(uuid.uuid4())
    try:
        # Built inside the try: SamplingParams validates its arguments, and
        # any error here happens after the HTTP response has already started.
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
        )
        async for output in engine.generate(
            prompt, sampling_params, request_id=request_id
        ):
            yield f"data: {json.dumps({'text': output.outputs[0].text})}\n\n"
    except Exception as exc:
        # The route handler has already returned when this body runs, so its
        # own try/except cannot catch these errors; report them in-band in
        # the same SSE error format. Cancellation (e.g. a client disconnect)
        # raises asyncio.CancelledError, a BaseException, which still
        # propagates.
        yield f"data: {json.dumps({'error': str(exc)})}\n\n"
        return
    yield "data: [DONE]\n\n"
# Register the handler below for POST requests to /generate-stream.
@app.post("/generate-stream")
async def generate_text(request: GenerationRequest):
    """Stream the model's output for ``request`` as Server-Sent Events.

    Text is forwarded to the client as the engine produces it, instead of
    after the whole completion has been generated.

    ``generate_stream(...)`` only creates an async generator here. Its body
    runs later, while ``StreamingResponse`` iterates it. The ``except`` below
    therefore covers only building the response, not errors raised while
    generating.
    """
    event_stream_type = "text/event-stream"
    try:
        events = generate_stream(
            request.prompt,
            request.max_tokens,
            request.temperature,
        )
        return StreamingResponse(events, media_type=event_stream_type)
    except Exception as exc:
        # Keep the SSE "data:" framing so clients parse the failure like any
        # other event.
        error_event = f"data: {json.dumps({'error': str(exc)})}\n\n"
        return StreamingResponse(iter([error_event]), media_type=event_stream_type)
@app.get("/")
def greet_json():
    """Return a constant JSON greeting for ``GET /``."""
    return dict(Hello="World!")