asoria committed
Commit 713d673 • 1 parent: 0b212ec

Adding dataset type validation

Files changed (4):
  1. app.py +10 -12
  2. notebooks/eda.json +1 -1
  3. notebooks/embeddings.json +1 -1
  4. notebooks/rag.json +1 -1
app.py CHANGED
@@ -125,6 +125,8 @@ def generate_cells(dataset_id, notebook_title):
     logging.info(f"Generating {notebook_title} notebook for dataset {dataset_id}")
     cells = notebook_templates[notebook_title]["notebook_template"]
     notebook_type = notebook_templates[notebook_title]["notebook_type"]
+    dataset_types = notebook_templates[notebook_title]["dataset_types"]
+
     try:
         libraries = get_compatible_libraries(dataset_id)
     except Exception as err:
@@ -155,22 +157,18 @@ def generate_cells(dataset_id, notebook_title):
     has_numeric_columns = len(df.select_dtypes(include=["number"]).columns) > 0
     has_categoric_columns = len(df.select_dtypes(include=["object"]).columns) > 0
 
-    # TODO: Validate by notebook type
-    if notebook_type in ("rag", "embeddings") and not has_categoric_columns:
-        logging.error(
-            "Dataset does not have categorical columns, which are required for RAG generation."
-        )
-        return (
-            "",
-            "## ❌ This dataset does not have categorical columns, which are required for Embeddings/RAG generation ❌",
-        )
-    if notebook_type == "eda" and not (has_categoric_columns or has_numeric_columns):
+    valid_dataset = False
+    if "text" in dataset_types and has_categoric_columns:
+        valid_dataset = True
+    if "numeric" in dataset_types and has_numeric_columns:
+        valid_dataset = True
+    if not valid_dataset:
         logging.error(
-            "Dataset does not have categorical or numeric columns, which are required for EDA generation."
+            f"Dataset does not have the column types needed for this notebook which expects to have {dataset_types} data types."
         )
         return (
             "",
-            "## ❌ This dataset does not have categorical or numeric columns, which are required for EDA generation ❌",
+            f"## ❌ This dataset does not have {dataset_types} columns, which are required for this notebook type ❌",
         )
 
     cells = replace_wildcards(
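Taken on its own, the new check in generate_cells reduces to a small predicate over the template's declared dataset_types and the pandas dtypes of the loaded DataFrame. A minimal sketch of that logic follows; the helper name is_dataset_valid and the toy DataFrame are illustrative and not part of this commit.

import pandas as pd

def is_dataset_valid(df: pd.DataFrame, dataset_types: list) -> bool:
    # Mirrors the validation added above: the dataset qualifies if it has
    # object (text) columns and the template accepts "text", or numeric
    # columns and the template accepts "numeric".
    has_numeric_columns = len(df.select_dtypes(include=["number"]).columns) > 0
    has_categoric_columns = len(df.select_dtypes(include=["object"]).columns) > 0
    if "text" in dataset_types and has_categoric_columns:
        return True
    if "numeric" in dataset_types and has_numeric_columns:
        return True
    return False

# Illustrative usage with a toy DataFrame:
df = pd.DataFrame({"question": ["What is EDA?"], "score": [0.9]})
print(is_dataset_valid(df, ["text"]))             # True: has an object column
print(is_dataset_valid(df, ["numeric", "text"]))  # True: either type suffices
print(is_dataset_valid(df.select_dtypes(include=["number"]), ["text"]))  # False

With a list of accepted types per template, a single check now covers EDA, embeddings, and RAG notebooks instead of one branch per notebook_type.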
notebooks/eda.json CHANGED
@@ -1,7 +1,7 @@
 {
     "notebook_title": "Exploratory data analysis (EDA)",
     "notebook_type": "eda",
-    "dataset_type": "numeric",
+    "dataset_types": ["numeric", "text"],
     "notebook_template": [
         {
             "cell_type": "markdown",
notebooks/embeddings.json CHANGED
@@ -1,7 +1,7 @@
 {
     "notebook_title": "Text Embeddings",
     "notebook_type": "embeddings",
-    "dataset_type": "text",
+    "dataset_types": ["text"],
     "notebook_template": [
         {
             "cell_type": "markdown",
notebooks/rag.json CHANGED
@@ -1,7 +1,7 @@
 {
     "notebook_title": "Retrieval-augmented generation (RAG)",
     "notebook_type": "rag",
-    "dataset_type": "text",
+    "dataset_types": ["text"],
     "notebook_template": [
         {
             "cell_type": "markdown",