Skip to content

Commit c9e8cf8

Browse files
committed
Fixes bug with text classification full models
1 parent 6a44389 commit c9e8cf8

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

examples/text-classification/sklearn/banking/demo-banking.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,8 @@
585585
"\n",
586586
" def predict_proba(self, input_data_df: pd.DataFrame):\n",
587587
" \"\"\"Makes predictions with the model. Returns the class probabilities.\"\"\"\n",
588-
" return self.model.predict_proba(input_data_df)\n",
588+
" text_column = input_data_df.columns[0]\n",
589+
" return self.model.predict_proba(input_data_df[text_column])\n",
589590
"\n",
590591
"\n",
591592
"def load_model():\n",
@@ -750,7 +751,7 @@
750751
"name": "python",
751752
"nbconvert_exporter": "python",
752753
"pygments_lexer": "ipython3",
753-
"version": "3.8.13"
754+
"version": "3.8.10"
754755
}
755756
},
756757
"nbformat": 4,

examples/text-classification/sklearn/sentiment-analysis/sentiment-sklearn.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,8 @@
596596
"\n",
597597
" def predict_proba(self, input_data_df: pd.DataFrame):\n",
598598
" \"\"\"Makes predictions with the model. Returns the class probabilities.\"\"\"\n",
599-
" return self.model.predict_proba(input_data_df)\n",
599+
" text_column = input_data_df.columns[0]\n",
600+
" return self.model.predict_proba(input_data_df[text_column])\n",
600601
"\n",
601602
"\n",
602603
"def load_model():\n",
@@ -761,7 +762,7 @@
761762
"name": "python",
762763
"nbconvert_exporter": "python",
763764
"pygments_lexer": "ipython3",
764-
"version": "3.8.13"
765+
"version": "3.8.10"
765766
}
766767
},
767768
"nbformat": 4,

openlayer/validators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def _validate_bundle_resources(self):
315315
sample_data = None
316316
if "textColumnName" in validation_dataset_config:
317317
sample_data = validation_dataset_df[
318-
validation_dataset_config["textColumnName"]
318+
[validation_dataset_config["textColumnName"]]
319319
].head()
320320

321321
else:

openlayer/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# Define the SDK version here so that the interal package can have access to this value.
22
# See https://stackoverflow.com/questions/2058802
3-
__version__ = "0.0.0a8"
3+
__version__ = "0.0.0a9"

0 commit comments

Comments
 (0)