diff --git a/src/text_generation/adapters/llm/text_generation_foundation_model.py b/src/text_generation/adapters/llm/text_generation_foundation_model.py
index f6e451336..cbd8cf733 100644
--- a/src/text_generation/adapters/llm/text_generation_foundation_model.py
+++ b/src/text_generation/adapters/llm/text_generation_foundation_model.py
@@ -47,16 +47,22 @@ class TextGenerationFoundationModel:
         # Create the text generation pipeline
         pipe = pipeline(
             "text-generation",
+            do_sample=True,
+            max_new_tokens=512,
             model=model,
-            tokenizer=tokenizer,
-            max_new_tokens=256,
-            temperature=0.7,
-            top_p=0.9,
             repetition_penalty=1.1,
+            temperature=0.3,
+            tokenizer=tokenizer,
             use_fast=True,
-            do_sample=True
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
         )
 
         # Create the LangChain LLM
-        return HuggingFacePipeline(pipeline=pipe)
+        return HuggingFacePipeline(
+            pipeline=pipe,
+            pipeline_kwargs={
+                "return_full_text": False,
+                "stop_sequence": ["<|end|>", "<|user|>", ""]
+            })
 