From 2083edafa1c9870ad254dd1e46f5c107a72c5863 Mon Sep 17 00:00:00 2001 From: Sanjiv Das Date: Wed, 18 Dec 2024 16:48:01 -0800 Subject: [PATCH] Update generate.py --- .../jupyter_ai/chat_handlers/generate.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 4b8dffc3e..1f633acf2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -1,3 +1,4 @@ +import ast import asyncio import os import time @@ -198,6 +199,15 @@ async def afill_outline(outline, llm, verbose=False): await asyncio.gather(*all_coros) +# Check if the content of the cell is python code or not +def is_not_python_code(source: str) -> bool: + try: + ast.parse(source) + return False + except (SyntaxError, ValueError): + return True + + def create_notebook(outline): """Create an nbformat Notebook object for a notebook outline.""" nbf = nbformat.v4 @@ -213,14 +223,23 @@ def create_notebook(outline): for code_block in section["code"].split("\n\n"): nb["cells"].append(nbf.new_code_cell(code_block)) - # Post process notebook for hanging cells: merge hanging cell with the previous cell - nb_cells = [] + # Post process notebook for hanging code cells: merge hanging cell with the previous cell + merged_cells = [] for cell in nb["cells"]: - if (cell["cell_type"] == "code") and (cell["source"][0] == " "): - nb_cells[-1]["source"] = nb_cells[-1]["source"] + "\n\n" + cell["source"] + # Fix a hanging code cell + follows_code_cell = merged_cells and merged_cells[-1]["cell_type"] == "code" + is_incomplete = cell["cell_type"] == "code" and cell["source"].startswith(" ") + if follows_code_cell and is_incomplete: + merged_cells[-1]["source"] = merged_cells[-1]["source"] + "\n\n" + cell["source"] else: - nb_cells.append(cell) - nb["cells"] = nb_cells + merged_cells.append(cell) + + # Fix code cells that should be markdown + for j in range(len(merged_cells)): + if merged_cells[j]["cell_type"]=="code" and is_not_python_code(merged_cells[j]["source"]): + merged_cells[j]["cell_type"] = "markdown" + + nb["cells"] = merged_cells return nb