Skip to content

Commit cf4dd92

Browse files
gustavocidornelas authored and whoseoyster committed
Allow validation sets without model outputs for LLMs (script and api)
1 parent fc78915 commit cf4dd92

File tree

3 files changed

+45
-3
lines changed

3 files changed

+45
-3
lines changed

openlayer/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import tempfile
2828
import time
2929
import uuid
30+
import warnings
3031
from typing import Optional
3132

3233
import pandas as pd
@@ -937,15 +938,47 @@ def commit(self, message: str, project_id: str, force: bool = False):
937938
print("Keeping the existing commit message.")
938939
return
939940

941+
llm_and_no_outputs = self._check_llm_and_no_outputs(project_dir=project_dir)
942+
if llm_and_no_outputs:
943+
warnings.warn(
944+
"You are committing an LLM without validation outputs computed "
945+
"in the validation set. This means that the platform will try to "
946+
"compute the validation outputs for you. This may take a while and "
947+
"there are costs associated with it."
948+
)
940949
commit = {
941950
"message": message,
942951
"date": time.ctime(),
952+
"computeOutputs": llm_and_no_outputs,
943953
}
944954
with open(f"{project_dir}/commit.yaml", "w", encoding="UTF-8") as commit_file:
945955
yaml.dump(commit, commit_file)
946956

947957
print("Committed!")
948958

959+
def _check_llm_and_no_outputs(self, project_dir: str) -> bool:
    """Report whether the staged bundle holds an LLM but no validation outputs.

    Two independent facts are gathered from ``project_dir``:

    * If a ``validation`` entry exists, its dataset config is loaded and the
      set is considered to lack outputs when ``outputColumnName`` is absent
      or ``None``.
    * If a ``model`` entry exists, ``model/model_config.yaml`` is read and the
      model counts as an LLM when ``architectureType`` is ``"llm"`` and
      ``modelType`` is anything other than ``"shell"``.

    Args:
        project_dir: Path of the project's staging area.

    Returns:
        ``True`` only when both conditions hold. A missing ``validation`` or
        ``model`` entry makes the corresponding condition ``False``.
    """
    # Validation side: no validation entry means "nothing is missing".
    missing_outputs = False
    validation_path = f"{project_dir}/validation"
    if os.path.exists(validation_path):
        dataset_config = utils.load_dataset_config_from_bundle(
            bundle_path=project_dir, label="validation"
        )
        missing_outputs = dataset_config.get("outputColumnName") is None

    # Model side: no model entry means "not an LLM".
    is_llm = False
    model_path = f"{project_dir}/model"
    if os.path.exists(model_path):
        config = utils.read_yaml(f"{project_dir}/model/model_config.yaml")
        architecture = config.get("architectureType")
        kind = config.get("modelType")
        is_llm = architecture == "llm" and kind != "shell"

    return missing_outputs and is_llm
981+
949982
def push(self, project_id: str, task_type: TaskType) -> Optional[ProjectVersion]:
950983
"""Pushes the commited resources to the platform.
951984

openlayer/schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ class LLMOutputSchema(BaseDatasetSchema):
153153
)
154154
outputColumnName = ma.fields.Str(
155155
validate=COLUMN_NAME_VALIDATION_LIST,
156-
required=True,
156+
allow_none=True,
157+
load_default=None,
157158
)
158159

159160

openlayer/validators/commit_validators.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import marshmallow as ma
88
import pandas as pd
9+
import yaml
910

1011
from .. import schemas, tasks, utils
1112
from . import baseline_model_validators, dataset_validators, model_validators
@@ -127,6 +128,11 @@ def _validate_bundle_state(self):
127128
label="validation"
128129
)
129130

131+
# Check if flagged to compute the model outputs
132+
with open(f"{self.bundle_path}/commit.yaml", "r") as commit_file:
133+
commit = yaml.safe_load(commit_file)
134+
compute_outputs = commit.get("computeOutputs", False)
135+
130136
if "model" in self._bundle_resources:
131137
model_type = self.model_config.get("modelType")
132138

@@ -163,7 +169,7 @@ def _validate_bundle_state(self):
163169
"training" not in self._bundle_resources
164170
or "fine-tuning" not in self._bundle_resources
165171
) and ("validation" in self._bundle_resources):
166-
if not outputs_in_validation_set:
172+
if not outputs_in_validation_set and not compute_outputs:
167173
self.failed_validations.append(
168174
"You are trying to push a model and a validation set to the platform. "
169175
"However, the validation set does not contain predictions. "
@@ -186,7 +192,9 @@ def _validate_bundle_state(self):
186192
"training" in self._bundle_resources
187193
or "fine-tuning" in self._bundle_resources
188194
) and ("validation" in self._bundle_resources):
189-
if not outputs_in_training_set or not outputs_in_validation_set:
195+
if (
196+
not outputs_in_training_set or not outputs_in_validation_set
197+
) and not compute_outputs:
190198
self.failed_validations.append(
191199
"You are trying to push a model, a training/fine-tuning set and a validation "
192200
"set to the platform. "

0 commit comments

Comments
 (0)