claudiubarbu committed
Commit
76a36b2
1 Parent(s): db72730

added streaming endpoint

Files changed (3)
  1. Dockerfile +13 -0
  2. app.py +119 -0
  3. requirements.txt +3 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.9
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,119 @@
+ from fastapi import FastAPI, Request
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ from vllm import AsyncLLMEngine, SamplingParams
+ from vllm.engine.arg_utils import AsyncEngineArgs
+ import json
+ import uuid
+
+ app = FastAPI()
+
+ # TODO: In the AsyncEngineArgs, select additional parameters
+ # to make this deployment efficient. Specifically, consider:
+ # - max_num_batched_tokens: Sets the maximum number of tokens that can be processed
+ #   in a single batch. Make sure to accommodate the memory constraints of the GPU hosting the application.
+ # - max_num_seqs: Limits the maximum number of sequences that can
+ #   be processed concurrently. Smaller numbers will reduce the memory pressure on the GPU.
+ # - gpu_memory_utilization: Sets the target GPU memory utilization.
+ #   Adjust it to make more efficient use of available GPU memory.
+ # - max_model_len: Specifies the maximum sequence length the model can handle.
+ # - enforce_eager: Disables or enables CUDA graph optimization. This can be useful
+ #   for debugging or when CUDA graph optimization causes issues.
+ # - dtype='half': Sets the data type for model parameters to half-precision
+ #   (float16). This reduces memory usage and can speed up computations, especially on GPUs with good half-precision performance.
+ engine = AsyncLLMEngine.from_engine_args(
+     AsyncEngineArgs(
+         model='claudiubarbu/HW2-orpo',
+         max_num_batched_tokens=1024,
+         max_num_seqs=8,
+         gpu_memory_utilization=0.8,
+         max_model_len=512,
+         enforce_eager=True,
+         dtype='half',
+     )
+ )
+
+ class GenerationRequest(BaseModel):
+     # FastAPI uses classes like GenerationRequest for several important reasons:
+     # - Automatic Request Parsing
+     # - Data Validation
+     # - Default Values
+     # - Self-Documenting APIs
+     # - Type Safety in Your Code
+     prompt: str
+     max_tokens: int = 100
+     temperature: float = 0.7
+
+
+ async def generate_stream(prompt: str, max_tokens: int, temperature: float):
+     """
+     The function generate_stream is an asynchronous generator that produces a stream of
+     text from a language model. Asynchronous functions can pause their execution,
+     allowing other code to run while waiting for operations to complete.
+
+     prompt: The initial text to start the generation.
+     max_tokens: The maximum number of tokens (words or word pieces) to generate.
+     temperature: Controls the randomness of the generation. Higher values (e.g., 1.0)
+     make output more random, while lower values (e.g., 0.1) make it more deterministic.
+     """
+
+     # SamplingParams configures how the text generation will behave.
+     # It uses the temperature and max_tokens values passed to the function.
+     sampling_params = SamplingParams(
+         temperature=temperature,
+         max_tokens=max_tokens
+     )
+
+     # The request_id is used by vLLM to track different generation requests,
+     # especially useful in scenarios with multiple concurrent requests.
+     # Using a UUID ensures that each request has a unique identifier,
+     # preventing conflicts between different generation tasks.
+     request_id = str(uuid.uuid4())
+
+     # async for is an asynchronous loop that works with asynchronous generators.
+     # engine.generate() streams output from the language model based on the given
+     # prompt and parameters. The loop receives chunks of generated text one at a
+     # time rather than waiting for the entire text to be generated.
+     # The generate call requires a request_id, which is set to the UUID created above.
+     async for output in engine.generate(prompt, sampling_params, request_id=request_id):
+         # yield is used in generator functions to produce a series of values
+         # over time rather than computing them all at once. The yielded string
+         # follows the Server-Sent Events (SSE) format:
+         # - It starts with "data: ".
+         # - The content is a JSON string containing the generated text.
+         # - It ends with two newlines (\n\n) to signal the end of an SSE message.
+         yield f"data: {json.dumps({'text': output.outputs[0].text})}\n\n"
+
+     # After the generation is complete, we yield a special "DONE" signal,
+     # also in SSE format, to indicate that the stream has ended.
+     yield "data: [DONE]\n\n"
+
+
+ # This decorator tells FastAPI that this function should handle POST requests
+ # to the "/generate-stream" endpoint.
+ @app.post("/generate-stream")
+ async def generate_text(request: GenerationRequest):
+     """
+     The function generate_text is a FastAPI route that handles POST requests to "/generate-stream".
+     It is designed to stream generated text back to the client as it is being produced,
+     rather than waiting for all the text to be generated before sending a response.
+     """
+     try:
+         # StreamingResponse is used to send a streaming response back to the client.
+         # generate_stream() is called with the parameters from the request; it is an async generator that yields chunks of text.
+         # media_type="text/event-stream" indicates that this is a Server-Sent Events (SSE) stream, a format for sending real-time updates from server to client.
+         return StreamingResponse(
+             generate_stream(request.prompt, request.max_tokens, request.temperature),
+             media_type="text/event-stream"
+         )
+     except Exception as e:
+         # If an exception occurs, return a streaming response with an error message,
+         # maintaining the SSE format.
+         return StreamingResponse(
+             iter([f"data: {json.dumps({'error': str(e)})}\n\n"]),
+             media_type="text/event-stream"
+         )
+
+ @app.get("/")
+ def greet_json():
+     return {"Hello": "World!"}
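
To try the new endpoint end to end, a client can POST to /generate-stream and read the Server-Sent Events as they arrive. Below is a minimal client sketch, assuming the container built from the Dockerfile above is running locally on port 7860 and that the requests package is available on the client side (it is not in requirements.txt); the prompt and sampling values are placeholders.

import json
import requests  # client-side dependency; not part of the Space's requirements.txt

# Assumed local address; the port matches the uvicorn CMD in the Dockerfile above.
url = "http://localhost:7860/generate-stream"
payload = {"prompt": "Tell me about ORPO fine-tuning.", "max_tokens": 100, "temperature": 0.7}

# stream=True keeps the HTTP connection open so SSE messages can be read as they arrive.
with requests.post(url, json=payload, stream=True) as response:
    text_so_far = ""
    for line in response.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip the blank separator lines between SSE messages
        data = line[len("data: "):]
        if data == "[DONE]":
            break  # end-of-stream marker emitted by generate_stream()
        # Each event appears to carry the text generated so far (vLLM yields
        # cumulative output), so remember only the most recent payload.
        text_so_far = json.loads(data)["text"]
        print(text_so_far)

Each printed line should grow as new tokens stream in; the final value of text_so_far is the complete generation.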
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ fastapi
+ uvicorn[standard]
+ vllm