Skip to content

Commit 9167c1d

Browse files
committed
improvement: allow specifying dataset as path for uploads
1 parent 60fd440 commit 9167c1d

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

src/openlayer/lib/data/batch_inferences.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,18 @@
1919
def upload_batch_inferences(
2020
client: Openlayer,
2121
inference_pipeline_id: str,
22-
dataset_df: pd.DataFrame,
2322
config: data_stream_params.Config,
23+
dataset_df: Optional[pd.DataFrame] = None,
24+
dataset_path: Optional[str] = None,
2425
storage_type: Optional[StorageType] = None,
2526
merge: bool = False,
2627
) -> None:
2728
"""Uploads a batch of inferences to the Openlayer platform."""
29+
if dataset_df is None and dataset_path is None:
30+
raise ValueError("Either dataset_df or dataset_path must be provided.")
31+
if dataset_df is not None and dataset_path is not None:
32+
raise ValueError("Only one of dataset_df or dataset_path should be provided.")
33+
2834
uploader = _upload.Uploader(client, storage_type)
2935
object_name = f"batch_data_{time.time()}_{inference_pipeline_id}.tar.gz"
3036

@@ -35,8 +41,11 @@ def upload_batch_inferences(
3541

3642
# Write dataset and config to temp directory
3743
with tempfile.TemporaryDirectory() as tmp_dir:
38-
temp_file_path = f"{tmp_dir}/dataset.csv"
39-
dataset_df.to_csv(temp_file_path, index=False)
44+
if dataset_df is not None:
45+
temp_file_path = f"{tmp_dir}/dataset.csv"
46+
dataset_df.to_csv(temp_file_path, index=False)
47+
else:
48+
temp_file_path = dataset_path
4049

4150
# Copy relevant files to tmp dir
4251
config["label"] = "production"
@@ -47,7 +56,11 @@ def upload_batch_inferences(
4756

4857
tar_file_path = os.path.join(tmp_dir, object_name)
4958
with tarfile.open(tar_file_path, mode="w:gz") as tar:
50-
tar.add(tmp_dir, arcname=os.path.basename("monitoring_data"))
59+
tar.add(temp_file_path, arcname=os.path.basename("dataset.csv"))
60+
tar.add(
61+
f"{tmp_dir}/dataset_config.yaml",
62+
arcname=os.path.basename("dataset_config.yaml"),
63+
)
5164

5265
# Upload to storage
5366
uploader.upload(

0 commit comments

Comments
 (0)