diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index d02d896e3c3d..81e9b8bb9f49 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -347,7 +347,10 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = N super().__init__(config, tokenizer) params_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".pdiparams") - model_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".pdmodel") + if paddle.framework.use_pir_api(): + model_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".json") + else: + model_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".pdmodel") inference_config = paddle.inference.Config(model_path, params_path) if self.config.device == "gpu":