diff --git a/tests/llm/phi3_language_model.py b/tests/llm/phi3_language_model.py index 542a396c6..7c6ab73b0 100644 --- a/tests/llm/phi3_language_model.py +++ b/tests/llm/phi3_language_model.py @@ -7,20 +7,20 @@ import argparse class Phi3LanguageModel: def __init__(self, model_path): - # configure the ONNX runtime + # configure ONNX runtime config = og.Config(model_path) config.clear_providers() self.model = og.Model(config) - self.tokenizer = og.Tokenizer(model) + self.tokenizer = og.Tokenizer(self.model) self.tokenizer_stream = self.tokenizer.create_stream() def get_response(self, args): search_options = { 'max_length': 2048 } - params = og.GeneratorParams(model) + params = og.GeneratorParams(self.model) params.set_search_options(**search_options) - generator = og.Generator(model, params) + generator = og.Generator(self.model, params) # process prompt input and generate tokens prompt_input = args.prompt