Merge pull request #2 from lightbroker/init

updates
This commit is contained in:
Adam Wilson
2025-04-23 06:57:57 -06:00
committed by GitHub
5 changed files with 105 additions and 23 deletions
+7 -5
View File
@@ -1,4 +1,4 @@
name: REST Server
name: LLM Prompt Testing
on:
# push:
@@ -22,13 +22,15 @@ jobs:
with:
python-version: '3.12'
- name: Download Huggingface CLI
- name: Set up HuggingFace LLM
run: |
pip install huggingface-hub[cli]
huggingface-cli download microsoft/Phi-3-mini-4k-instruct-onnx --include cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/* --local-dir .
pip install onnxruntime-genai
curl https://raw.githubusercontent.com/microsoft/onnxruntime-genai/main/examples/python/phi3-qa.py -o phi3-qa.py
python phi3-qa.py -m cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4 -e cpu -v
# curl https://raw.githubusercontent.com/microsoft/onnxruntime-genai/main/examples/python/phi3-qa.py -o phi3-qa.py
python ${{ github.workspace }}/tests/llm/phi3_language_model.py \
--prompt 'Describe the principle of existence, from the first principles of philosophy.' \
-m cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4 \
- name: Run REST API server
run: |
@@ -36,4 +38,4 @@ jobs:
- name: Test API call
run: |
curl -i localhost:9999/hello
curl -i localhost:9999
+32 -9
View File
@@ -1,22 +1,45 @@
import cgi
import json
class PathDispatcher:
def __init__(self):
self.routes = {}
self.response_headers = [('Content-Type', 'application/json')]
def __http_415_notsupported(self, env, start_response):
start_response('415 Unsupported Media Type', self.response_headers)
return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
def __http_200_ok(self, env, start_response):
try:
request_body_size = int(env.get('CONTENT_LENGTH', 0))
except (ValueError):
request_body_size = 0
request_body = env['wsgi.input'].read(request_body_size)
data = json.loads(request_body.decode('utf-8'))
start_response('200 OK', self.response_headers)
return [json.dumps({'received': data}).encode('utf-8')]
def notfound_404(self, env, start_response):
start_response('404 Not Found', [ ('Content-Type', 'text/plain') ])
return [b'Not Found']
def __call__(self, env, start_response):
method = env.get('REQUEST_METHOD').upper()
path = env.get('PATH_INFO')
params = cgi.FieldStorage(env.get('wsgi.output'), environ=env)
method = env.get('REQUEST_METHOD').lower()
env['params'] = { key: params.getvalue(key) for key in params }
handler = self.routes.get((method,path), self.notfound_404)
return handler(env, start_response)
if not method == 'POST':
self.__http_415_notsupported(env, start_response)
try:
handler = self.routes.get((method,path), self.__http_200_ok)
return handler(env, start_response)
except json.JSONDecodeError:
start_response('400 Bad Request', self.response_headers)
return [json.dumps({'error': 'Invalid JSON'}).encode('utf-8')]
def register(self, method, path, function):
self.routes[method.lower(), path] = function
return function
+10 -8
View File
@@ -1,3 +1,5 @@
import json
from PathDispatcher import PathDispatcher
from wsgiref.simple_server import make_server
@@ -6,19 +8,19 @@ class RestApiServer:
def __init__(self):
pass
def response_function(self, environ, start_response):
start_response('200 OK', [('Content-Type','text/html')])
yield str(f'testing...\n').encode('utf-8')
def post_response(self, env, start_response):
start_response('200 OK', [('Content-Type', 'application/json')])
yield [json.dumps({'received': 'data'}).encode('utf-8')]
def listen(self):
port = 9999
dispatcher = PathDispatcher()
dispatcher.register('GET', '/hello', self.response_function)
wsgi_srv = make_server('', port, dispatcher)
print(f'listening on port {port}...')
wsgi_srv.serve_forever()
dispatcher.register('POST', '/', self.post_response)
with make_server('', port, dispatcher) as wsgi_srv:
print(f'listening on port {port}...')
wsgi_srv.serve_forever()
if __name__ == '__main__':
srv = RestApiServer()
srv.listen()
srv.listen()
-1
View File
@@ -1 +0,0 @@
# TODO: business logic for REST API interaction w/ LLM via prompt input
+56
View File
@@ -0,0 +1,56 @@
# TODO: business logic for REST API interaction w/ LLM via prompt input
import onnxruntime_genai as og
import argparse
class Phi3LanguageModel:
def __init__(self, model_path):
# configure ONNX runtime
config = og.Config(model_path)
config.clear_providers()
self.model = og.Model(config)
self.tokenizer = og.Tokenizer(self.model)
self.tokenizer_stream = self.tokenizer.create_stream()
def get_response(self, prompt_input):
search_options = { 'max_length': 2048 }
params = og.GeneratorParams(self.model)
params.set_search_options(**search_options)
generator = og.Generator(self.model, params)
# process prompt input and generate tokens
chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
prompt = f'{chat_template.format(input=prompt_input)}'
input_tokens = self.tokenizer.encode(prompt)
generator.append_tokens(input_tokens)
print("Output: ", end='', flush=True)
try:
while not generator.is_done():
generator.generate_next_token()
new_token = generator.get_next_tokens()[0]
print(self.tokenizer_stream.decode(new_token), end='', flush=True)
except Exception as e:
print(f'{e}')
if __name__ == "__main__":
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
parser.add_argument('-m', '--model_path', type=str, required=True, help='Onnx model folder path (must contain genai_config.json and model.onnx)')
parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt input')
parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
parser.add_argument('-ds', '--do_sample', action='store_true', default=False, help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
parser.add_argument('--top_p', type=float, help='Top p probability to sample with')
parser.add_argument('--top_k', type=int, help='Top k tokens to sample from')
parser.add_argument('--temperature', type=float, help='Temperature to sample with')
parser.add_argument('--repetition_penalty', type=float, help='Repetition penalty to sample with')
args = parser.parse_args()
model = Phi3LanguageModel(args.model_path)
model.get_response(args.prompt)