Skip to content

Commit d2c762c

Browse files
gustavocidornelas authored and whoseoyster committed
Fix part of pylint issues
1 parent cf4dd92 commit d2c762c

File tree

5 files changed

+42
-32
lines changed

5 files changed

+42
-32
lines changed

openlayer/constants.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
"""Module for storing constants used throughout the OpenLayer Python Client.
2+
"""
13
import os
24

35
# ---------------------------- Commit/staging flow --------------------------- #

openlayer/model_runners/ll_model_runners.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
"""
55

66
import datetime
7-
import json
87
import logging
98
import warnings
109
from abc import ABC, abstractmethod
@@ -59,32 +58,33 @@ def run(
5958
"""Runs the input data through the model."""
6059
if self.in_memory:
6160
return self._run_in_memory(
62-
input_data_df=input_data,
61+
input_data=input_data,
6362
output_column_name=output_column_name,
6463
)
6564
else:
6665
return self._run_in_conda(
67-
input_data_df=input_data, output_column_name=output_column_name
66+
input_data=input_data, output_column_name=output_column_name
6867
)
6968

7069
def _run_in_memory(
7170
self,
72-
input_data_df: pd.DataFrame,
71+
input_data: pd.DataFrame,
7372
output_column_name: Optional[str] = None,
7473
) -> pd.DataFrame:
7574
"""Runs the input data through the model in memory and returns a pandas
7675
dataframe."""
7776
for output_df, _ in tqdm(
78-
self._run_in_memory_and_yield_progress(input_data_df, output_column_name),
79-
total=len(input_data_df),
77+
self._run_in_memory_and_yield_progress(input_data, output_column_name),
78+
total=len(input_data),
8079
colour="BLUE",
8180
):
8281
pass
82+
# pylint: disable=undefined-loop-variable
8383
return output_df
8484

8585
def _run_in_memory_and_yield_progress(
8686
self,
87-
input_data_df: pd.DataFrame,
87+
input_data: pd.DataFrame,
8888
output_column_name: Optional[str] = None,
8989
) -> Generator[Tuple[pd.DataFrame, float], None, None]:
9090
"""Runs the input data through the model in memory and yields the results
@@ -95,10 +95,10 @@ def _run_in_memory_and_yield_progress(
9595
timestamps = []
9696
run_exceptions = set()
9797
run_cost = 0
98-
total_rows = len(input_data_df)
98+
total_rows = len(input_data)
9999
current_row = 0
100100

101-
for _, input_data_row in input_data_df.iterrows():
101+
for _, input_data_row in input_data.iterrows():
102102
# Check if output column already has a value to avoid re-running
103103
if output_column_name and output_column_name in input_data_row:
104104
output_value = input_data_row[output_column_name]
@@ -149,6 +149,7 @@ def _run_single_input(self, input_data_row: pd.Series) -> Tuple[str, float, set]
149149
try:
150150
outputs = self._get_llm_output(llm_input)
151151
return outputs["output"], outputs["cost"], set()
152+
# pylint: disable=broad-except
152153
except Exception as exc:
153154
return None, 0, {exc}
154155

@@ -223,7 +224,7 @@ def _report_exceptions(self, exceptions: set) -> None:
223224
)
224225

225226
def _run_in_conda(
226-
self, input_data_df: pd.DataFrame, output_column_name: Optional[str] = None
227+
self, input_data: pd.DataFrame, output_column_name: Optional[str] = None
227228
) -> pd.DataFrame:
228229
"""Runs LLM prediction job in a conda environment."""
229230
raise NotImplementedError(
@@ -253,7 +254,7 @@ def run_and_yield_progress(
253254
"""Runs the input data through the model and yields progress."""
254255
if self.in_memory:
255256
yield from self._run_in_memory_and_yield_progress(
256-
input_data_df=input_data,
257+
input_data=input_data,
257258
output_column_name=output_column_name,
258259
)
259260
else:
@@ -376,7 +377,7 @@ def _initialize_llm(self):
376377
raise ValueError(
377378
"Cohere API key is invalid. Please pass a valid API key as the "
378379
f"keyword argument 'cohere_api_key' \n Error message: {e}"
379-
)
380+
) from e
380381
if self.model_config.get("model") is None:
381382
warnings.warn("No model specified. Defaulting to model 'command'.")
382383
if self.model_config.get("model_parameters") is None:
@@ -461,7 +462,7 @@ def _initialize_llm(self):
461462
raise ValueError(
462463
"OpenAI API key is invalid. Please pass a valid API key as the "
463464
f"keyword argument 'openai_api_key' \n Error message: {e}"
464-
)
465+
) from e
465466
if self.model_config.get("model") is None:
466467
warnings.warn("No model specified. Defaulting to model 'gpt-3.5-turbo'.")
467468
if self.model_config.get("model_parameters") is None:
@@ -539,12 +540,13 @@ def _initialize_llm(self):
539540
"""Initializes the self-hosted LL model."""
540541
# Check if API key is valid
541542
try:
542-
requests.get(self.url)
543+
# TODO: move request timeout to constants.py
544+
requests.get(self.url, timeout=10800)
543545
except Exception as e:
544546
raise ValueError(
545547
"URL is invalid. Please pass a valid URL as the "
546548
f"keyword argument 'url' \n Error message: {e}"
547-
)
549+
) from e
548550

549551
def _get_llm_input(self, injected_prompt: List[Dict[str, str]]) -> str:
550552
"""Prepares the input for the self-hosted LLM."""
@@ -572,7 +574,8 @@ def _make_request(self, llm_input: str) -> Dict[str, Any]:
572574
"Content-Type": "application/json",
573575
}
574576
data = {self.input_key: llm_input}
575-
response = requests.post(self.url, headers=headers, json=data)
577+
# TODO: move request timeout to constants.py
578+
response = requests.post(self.url, headers=headers, json=data, timeout=10800)
576579
if response.status_code == 200:
577580
response_data = response.json()[0]
578581
return response_data
@@ -592,4 +595,6 @@ class HuggingFaceModelRunner(SelfHostedLLModelRunner):
592595
"""Wraps LLMs hosted in HuggingFace."""
593596

594597
def __init__(self, url, api_key):
595-
super().__init__(url, api_key, input_key="inputs", output_key="generated_text")
598+
super().__init__(
599+
url=url, api_key=api_key, input_key="inputs", output_key="generated_text"
600+
)

openlayer/model_runners/tests/test_llm_runners.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -54,26 +54,30 @@
5454
)
5555

5656
# ----------------------------- Expected results ----------------------------- #
57+
# flake8: noqa: E501
5758
OPENAI_PROMPT = [
5859
*PROMPT[:-1],
5960
{
6061
"role": "user",
6162
"content": """description: A smartwatch with fitness tracking capabilities \n\nseed words: smart, fitness, health""",
6263
},
6364
]
64-
COHERE_PROMPT = """S: You are a helpful assistant.
65+
66+
# flake8: noqa: E501
67+
COHERE_PROMPT = """S: You are a helpful assistant.
6568
U: You will be provided with a product description and seed words, and your task is to generate a list
6669
of product names and provide a short description of the target customer for such product. The output
67-
must be a valid JSON with attributes `names` and `target_custommer`.
68-
A: Let\'s get started!
69-
U: Product description: \n description: A home milkshake maker \n seed words: fast, healthy, compact
70-
A: {\n "names": ["QuickBlend", "FitShake", "MiniMix"]\n "target_custommer": "College students that are into fitness and healthy living"\n}
71-
U: description: A smartwatch with fitness tracking capabilities \n\nseed words: smart, fitness, health
70+
must be a valid JSON with attributes `names` and `target_custommer`.
71+
A: Let\'s get started!
72+
U: Product description: \n description: A home milkshake maker \n seed words: fast, healthy, compact
73+
A: {\n "names": ["QuickBlend", "FitShake", "MiniMix"]\n "target_custommer": "College students that are into fitness and healthy living"\n}
74+
U: description: A smartwatch with fitness tracking capabilities \n\nseed words: smart, fitness, health
7275
A:"""
7376

77+
# flake8: noqa: E501
7478
ANTHROPIC_PROMPT = f"""{anthropic.HUMAN_PROMPT} You are a helpful assistant. {anthropic.HUMAN_PROMPT} You will be provided with a product description and seed words, and your task is to generate a list
7579
of product names and provide a short description of the target customer for such product. The output
76-
must be a valid JSON with attributes `names` and `target_custommer`. {anthropic.AI_PROMPT} Let\'s get started! {anthropic.HUMAN_PROMPT} Product description:
80+
must be a valid JSON with attributes `names` and `target_custommer`. {anthropic.AI_PROMPT} Let\'s get started! {anthropic.HUMAN_PROMPT} Product description:
7781
description: A home milkshake maker \n seed words: fast, healthy, compact {anthropic.AI_PROMPT} {{\n "names": ["QuickBlend", "FitShake", "MiniMix"]\n "target_custommer": "College students that are into fitness and healthy living"\n}} {anthropic.HUMAN_PROMPT} description: A smartwatch with fitness tracking capabilities \n\nseed words: smart, fitness, health {anthropic.AI_PROMPT}"""
7882

7983

openlayer/validators/commit_validators.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,9 @@ def _validate_bundle_state(self):
129129
)
130130

131131
# Check if flagged to compute the model outputs
132-
with open(f"{self.bundle_path}/commit.yaml", "r") as commit_file:
132+
with open(
133+
f"{self.bundle_path}/commit.yaml", "r", encoding="UTF-8"
134+
) as commit_file:
133135
commit = yaml.safe_load(commit_file)
134136
compute_outputs = commit.get("computeOutputs", False)
135137

@@ -262,7 +264,7 @@ def _validate_bundle_resources(self):
262264
if "model" in self._bundle_resources and not self._skip_model_validation:
263265
model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
264266
model_type = self.model_config.get("modelType")
265-
if model_type == "shell" or model_type == "api":
267+
if model_type in ("shell", "api"):
266268
model_validator = model_validators.get_validator(
267269
task_type=self.task_type,
268270
model_config_file_path=model_config_file_path,

openlayer/validators/model_validators.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -619,15 +619,12 @@ def get_validator(
619619

620620

621621
# --------------- Helper functions used by multiple validators --------------- #
622-
def dir_exceeds_size_limit(dir: str) -> bool:
622+
def dir_exceeds_size_limit(dir_path: str) -> bool:
623623
"""Checks whether the tar version of the directory exceeds the maximum limit."""
624624
with tempfile.TemporaryDirectory() as tmp_dir:
625625
tar_file_path = os.path.join(tmp_dir, "tarfile")
626626
with tarfile.open(tar_file_path, mode="w:gz") as tar:
627-
tar.add(dir, arcname=os.path.basename(dir))
627+
tar.add(dir_path, arcname=os.path.basename(dir_path))
628628
tar_file_size = os.path.getsize(tar_file_path)
629629

630-
if tar_file_size > constants.MAXIMUM_TAR_FILE_SIZE * 1024 * 1024:
631-
return True
632-
else:
633-
return False
630+
return tar_file_size > constants.MAXIMUM_TAR_FILE_SIZE * 1024 * 1024

0 commit comments

Comments (0)