diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index a69b5ed28..6318e0979 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -1,3 +1,4 @@ +import ast import asyncio import os import time @@ -198,6 +199,15 @@ async def afill_outline(outline, llm, verbose=False): await asyncio.gather(*all_coros) +# Check if the content of the cell is python code or not +def is_not_python_code(source: str) -> bool: + try: + ast.parse(source) + except: + return True + return False + + def create_notebook(outline): """Create an nbformat Notebook object for a notebook outline.""" nbf = nbformat.v4 @@ -212,6 +222,26 @@ def create_notebook(outline): nb["cells"].append(nbf.new_markdown_cell("## " + section["title"])) for code_block in section["code"].split("\n\n"): nb["cells"].append(nbf.new_code_cell(code_block)) + + # Post process notebook for hanging code cells: merge hanging cell with the previous cell + merged_cells = [] + for cell in nb["cells"]: + # Fix a hanging code cell + follows_code_cell = merged_cells and merged_cells[-1]["cell_type"] == "code" + is_incomplete = cell["cell_type"] == "code" and cell["source"].startswith(" ") + if follows_code_cell and is_incomplete: + merged_cells[-1]["source"] = ( + merged_cells[-1]["source"] + "\n\n" + cell["source"] + ) + else: + merged_cells.append(cell) + + # Fix code cells that should be markdown + for cell in merged_cells: + if cell["cell_type"] == "code" and is_not_python_code(cell["source"]): + cell["cell_type"] = "markdown" + + nb["cells"] = merged_cells return nb