Environment setup

Go to the vLLM GitHub repository, vllm-project/vllm ("A high-throughput and memory-efficient inference and serving engine for LLMs"), and install it by following the instructions there. The steps are not repeated here.

Running the API

OpenAI-compatible API server

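Run the command below to start the server. By default it listens on host 0.0.0.0 and port 8000; use --host and --port to change them. For models such as ChatGLM, pass the --trust-remote-code flag.

python -m vllm.entrypoints.openai.api_server --model your_model_path --trust-remote-code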
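To launch the server from a Python script (for example, in a test harness), you can assemble the same command as an argument list. This is a minimal sketch: `build_server_cmd` and `start_server` are hypothetical helpers, and they use only the flags mentioned above (--model, --host, --port, --trust-remote-code).

```python
import shlex
import subprocess
import sys


def build_server_cmd(model_path, host="0.0.0.0", port=8000, trust_remote_code=False):
    """Return the argv list that starts vLLM's OpenAI-compatible server (hypothetical helper)."""
    cmd = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model_path,
        "--host", host,
        "--port", str(port),
    ]
    if trust_remote_code:
        # Needed for models such as ChatGLM that ship their own modeling code
        cmd.append("--trust-remote-code")
    return cmd


def start_server(cmd):
    """Launch the server as a child process; call .terminate() on the result when done."""
    return subprocess.Popen(cmd)


if __name__ == "__main__":
    cmd = build_server_cmd("your_model_path", trust_remote_code=True)
    # Print the equivalent shell command instead of starting the server
    print(shlex.join(cmd))
```

Passing a list instead of a single string avoids shell quoting problems when the model path contains spaces.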

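You can test the server with the request below. Note that the model parameter must always be passed, and its value must match the --model argument the server was started with. Replace localhost with the server's address when calling from another machine.

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "your_model_path",
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "temperature": 0
    }'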
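If curl is not available, the same request can be built with only the Python standard library. This is a sketch under the assumption that the server runs at http://localhost:8000; `build_completion_request` and `send` are hypothetical helpers, and the empty-model check is a client-side guard, not vLLM's own validation.

```python
import json
import urllib.request


def build_completion_request(base_url, model, prompt, **sampling):
    """Build (but do not send) a POST request for the /v1/completions endpoint."""
    if not model:
        # Client-side guard: the model field must always be sent
        raise ValueError("model is required")
    body = {"model": model, "prompt": prompt, **sampling}
    return urllib.request.Request(
        base_url.rstrip("/") + "/v1/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=60):
    """Send the request and return the decoded JSON response body."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Usage, mirroring the curl call: `send(build_completion_request("http://localhost:8000", "your_model_path", "San Francisco is a", max_tokens=7, temperature=0))["choices"][0]["text"]`.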

For the other supported parameters, see sampling_params in https://github.com/vllm-project/vllm/blob/9b294976a2373f6fda22c1b2e704c395c8bd0787/vllm/entrypoints/openai/api_server.py#L252:

sampling_params = SamplingParams(
            n=request.n,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
            stop_token_ids=request.stop_token_ids,
            max_tokens=request.max_tokens,
            best_of=request.best_of,
            top_k=request.top_k,
            ignore_eos=request.ignore_eos,
            use_beam_search=request.use_beam_search,
            skip_special_tokens=request.skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

For the meaning of each parameter, see the documentation of the SamplingParams class in https://github.com/vllm-project/vllm/blob/9b294976a2373f6fda22c1b2e704c395c8bd0787/vllm/sampling_params.py.
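Because the server forwards these request fields into SamplingParams, a misspelled name in the JSON body is easy to miss. The sketch below checks names against the list in the excerpt above before sending. `with_sampling` and `SAMPLING_FIELDS` are hypothetical; the set only covers the fields in that excerpt, so other request fields such as logprobs should be placed in the payload directly.

```python
# Sampling fields copied from the SamplingParams(...) excerpt above
SAMPLING_FIELDS = {
    "n", "presence_penalty", "frequency_penalty", "temperature", "top_p",
    "stop", "stop_token_ids", "max_tokens", "best_of", "top_k",
    "ignore_eos", "use_beam_search", "skip_special_tokens",
    "spaces_between_special_tokens",
}


def with_sampling(payload, **params):
    """Return a copy of payload with sampling parameters added.

    Names outside SAMPLING_FIELDS raise ValueError, catching typos such as
    "topk" before the request reaches the server.
    """
    unknown = set(params) - SAMPLING_FIELDS
    if unknown:
        raise ValueError(f"unknown sampling parameters: {sorted(unknown)}")
    # Build a new dict so the caller's payload is left unchanged
    return {**payload, **params}
```

For example, `with_sampling({"model": "your_model_path", "prompt": "San Francisco is a"}, top_k=20, temperature=0.7)` returns a new body that includes both sampling fields.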

Calling from a program

import requests

MODEL = "your_model_path"  # must match the --model value used to start the server
URL = "http://localhost:8000/v1/completions"


def get_response(prompt):
    payload = {
        "model": MODEL,
        "prompt": prompt,
        "logprobs": 1,
        "max_tokens": 100,
        "temperature": 0,
    }
    # json= serializes the payload and sets Content-Type: application/json
    response = requests.post(URL, json=payload, timeout=60)
    if response.status_code != 200:
        # Print the server's error message and signal failure to the caller
        print(response.status_code, response.text)
        return None
    return response.json()["choices"][0]["text"]
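The request above sets "logprobs": 1 but only reads the generated text. The sketch below also extracts per-token log probabilities. It assumes the OpenAI completions response layout, in which choices[i]["logprobs"] contains parallel "tokens" and "token_logprobs" lists. `parse_choice` is a hypothetical helper.

```python
def parse_choice(response_json, index=0):
    """Return (text, [(token, logprob), ...]) for one choice of a completions response.

    Assumes the OpenAI completions layout; when logprobs was not requested,
    the "logprobs" field may be null and the list comes back empty.
    """
    choice = response_json["choices"][index]
    logprobs = choice.get("logprobs") or {}
    tokens = logprobs.get("tokens") or []
    token_logprobs = logprobs.get("token_logprobs") or []
    # Pair each generated token with its log probability
    return choice["text"], list(zip(tokens, token_logprobs))
```

Inside get_response, `text, token_scores = parse_choice(response.json())` would replace the last line if you also need the scores.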

Asynchronous batch requests with grequests

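import grequests  # imported first because it monkey-patches the standard library via gevent
import json
import time

MODEL = "your_model_path"  # must match the --model value used to start the server
prompt = "San Francisco is a"  # example prompt

headers = {'Content-Type': 'application/json'}
data = {
    "model": MODEL,
    "prompt": prompt,
    "logprobs": 1,
    "max_tokens": 100,
    "temperature": 0
}
start = time.time()
# Build 10 identical requests; nothing is sent yet
req_list = [
    grequests.post('http://localhost:8000/v1/completions', data=json.dumps(data), headers=headers)
    for _ in range(10)
]
# Send all requests concurrently and wait for every response
res_list = grequests.map(req_list)
print(round(time.time() - start, 1))  # total elapsed time in seconds
# grequests.map returns None for a request that failed
if res_list[0] is not None:
    print(res_list[0].json()["choices"][0]['text'])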
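If you prefer not to depend on gevent, a thread pool from the standard library can issue the same kind of concurrent batch. This swaps in a different technique (threads instead of grequests' gevent greenlets). `batch_complete` is a hypothetical helper, and `send_one` can be any single-request function, such as get_response above.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def batch_complete(prompts, send_one, max_workers=10):
    """Call send_one(prompt) for every prompt concurrently.

    Results come back in the same order as prompts; an exception raised for
    one prompt is stored in that slot instead of aborting the whole batch.
    """
    def safe(prompt):
        try:
            return send_one(prompt)
        except Exception as exc:  # keep the failure next to its prompt
            return exc

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(safe, prompts))


if __name__ == "__main__":
    def fake_send(prompt):
        time.sleep(0.1)  # stands in for a network round trip
        return prompt.upper()

    start = time.time()
    print(batch_complete(["a", "b", "c"], fake_send))
    print(round(time.time() - start, 1))  # elapsed seconds for the whole batch
```

Threads suit this workload because each worker spends most of its time waiting on the server, so the GIL is not a bottleneck.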