📦 3b1b / manim

📄 scene_embed.py · 204 lines
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204from __future__ import annotations

import inspect
import pyperclip
import traceback

from IPython.terminal import pt_inputhooks
from IPython.terminal.embed import InteractiveShellEmbed

from manimlib.animation.fading import VFadeInThenOut
from manimlib.config import manim_config
from manimlib.constants import RED
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.frame import FullScreenRectangle
from manimlib.module_loader import ModuleLoader


from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from manimlib.scene.scene import Scene


class InteractiveSceneEmbed:
    """
    An embedded IPython shell bound to a running `Scene`.

    Construction builds a shell whose namespace is the module of the user code
    that requested the embed, then installs IPython hooks that:
      * keep the scene window rendering while the shell waits for input,
      * redraw the frame after every cell,
      * print the traceback and flash a red border on exceptions,
      * optionally (``manim_config.embed.autoreload``) re-import the user's
        module before every cell.
    """

    def __init__(self, scene: Scene):
        self.scene = scene
        self.checkpoint_manager = CheckpointManager()

        # Must be called directly from __init__: the method locates the user's
        # code by walking a fixed number of frames up the call stack.
        self.shell = self.get_ipython_shell_for_embedded_scene()
        self.enable_gui()
        self.ensure_frame_update_post_cell()
        self.ensure_flash_on_error()
        if manim_config.embed.autoreload:
            self.auto_reload()

    def launch(self):
        """Start the embedded shell by calling it."""
        self.shell()

    def get_ipython_shell_for_embedded_scene(self) -> InteractiveShellEmbed:
        """
        Create an embedded IPython shell with access to the caller's namespace.

        The user's module (found via ``ModuleLoader.get_module`` from the
        caller frame's ``__file__``) becomes the shell's ``user_module``. Its
        globals are updated first with the caller's local variables and then
        with the entries from `get_shortcuts`, so shortcuts win on name clashes.

        This relies on a fixed call depth: this method <- ``__init__`` <- the
        code that constructed this object (expected to be the scene's embed
        method) <- the user's scene code. Changing that chain breaks the lookup.
        """
        # Triple back should take us to the context in a user's scene definition
        # which is calling "self.embed"
        caller_frame = inspect.currentframe().f_back.f_back.f_back

        # Update the module's namespace to include local variables
        module = ModuleLoader.get_module(caller_frame.f_globals["__file__"])
        module.__dict__.update(caller_frame.f_locals)
        module.__dict__.update(self.get_shortcuts())
        exception_mode = manim_config.embed.exception_mode

        return InteractiveShellEmbed(
            user_module=module,
            display_banner=False,
            xmode=exception_mode
        )

    def get_shortcuts(self):
        """
        Return a name -> callable mapping injected into the shell namespace.

        Most entries are bound methods of the scene (play, wait, add, ...);
        the rest are checkpoint helpers and `reload_scene` (exposed as
        ``reload``).
        """
        scene = self.scene
        return dict(
            play=scene.play,
            wait=scene.wait,
            add=scene.add,
            remove=scene.remove,
            clear=scene.clear,
            focus=scene.focus,
            save_state=scene.save_state,
            undo=scene.undo,
            redo=scene.redo,
            i2g=scene.i2g,
            i2m=scene.i2m,
            checkpoint_paste=self.checkpoint_paste,
            clear_checkpoints=self.checkpoint_manager.clear_checkpoints,
            reload=self.reload_scene  # Defined below
        )

    def enable_gui(self):
        """
        Keep the scene window rendering while the shell waits for input.

        Registers an input hook named "manim" that repeatedly calls
        ``scene.update_frame(dt=0)`` until input is ready (skipping the draw
        while the window is closing). Once input is ready, it asks the shell to
        exit if the window is closing.
        """
        def inputhook(context):
            while not context.input_is_ready():
                if not self.scene.is_window_closing():
                    self.scene.update_frame(dt=0)
            if self.scene.is_window_closing():
                self.shell.ask_exit()

        pt_inputhooks.register("manim", inputhook)
        self.shell.enable_gui("manim")

    def ensure_frame_update_post_cell(self):
        """Force a frame redraw after each IPython cell, unless the window is closing."""
        def post_cell_func(*args, **kwargs):
            if not self.scene.is_window_closing():
                self.scene.update_frame(dt=0, force_draw=True)

        self.shell.events.register("post_run_cell", post_cell_func)

    def ensure_flash_on_error(self):
        """
        On any `Exception` raised by a cell, print the traceback and briefly
        flash a red border around the frame.
        """
        def custom_exc(shell, etype, evalue, tb, tb_offset=None):
            # Show the error; don't just swallow it
            shell.showtraceback((etype, evalue, tb), tb_offset=tb_offset)
            # Like the input hook and post-cell hook above, don't try to
            # render anything once the window is closing.
            if self.scene.is_window_closing():
                return
            rect = FullScreenRectangle().set_stroke(RED, 30).set_fill(opacity=0)
            rect.fix_in_frame()
            self.scene.play(VFadeInThenOut(rect, run_time=0.5))

        self.shell.set_custom_exc((Exception,), custom_exc)

    def reload_scene(self, embed_line: int | None = None) -> None:
        """
        Reload the scene the same way the `manimgl` command would, using the
        arguments provided at initial startup. This allows quick iteration
        during scene development without exiting the IPython shell and
        re-running `manimgl`. The GUI stays open during the reload.

        If `embed_line` is provided, the scene is reloaded with the embed at
        that line number. This corresponds to the `linemarker` param of the
        `extract_scene.insert_embed_line_to_module()` method.

        Before reloading, the scene is cleared and its entire state reset so
        that we start from a clean slate. This is handled by the run_scenes
        function in __main__.py, which catches the error raised by the
        `exit_raise` magic command invoked here.

        Note that we cannot define a custom exception class for this, since
        the IPython kernel swallows any exception. While such an exception
        could be caught by the custom handler registered with
        `set_custom_exc`, that handler cannot break out of the IPython shell.
        """
        # Update the global run configuration.
        run_config = manim_config.run
        run_config.is_reload = True
        # A falsy embed_line (None or 0) leaves the configured value unchanged.
        if embed_line:
            run_config.embed_line = embed_line

        print("Reloading...")
        self.shell.run_line_magic("exit_raise", "")

    def auto_reload(self):
        """Before each cell runs, reload the shell's module and copy its namespace into the shell."""
        def pre_cell_func(*args, **kwargs):
            new_mod = ModuleLoader.get_module(self.shell.user_module.__file__, is_during_reload=True)
            self.shell.user_ns.update(vars(new_mod))

        self.shell.events.register("pre_run_cell", pre_cell_func)

    def checkpoint_paste(
        self,
        skip: bool = False,
        record: bool = False,
        progress_bar: bool = True
    ):
        """
        Run the clipboard contents through the checkpoint manager inside
        ``scene.temp_config_change(skip, record, progress_bar)``.
        """
        with self.scene.temp_config_change(skip, record, progress_bar):
            self.checkpoint_manager.checkpoint_paste(self.shell, self.scene)


class CheckpointManager:
    """
    Tracks named scene checkpoints for interactive "checkpoint paste".

    A checkpoint is keyed by the leading comment line of a pasted code block.
    The first time a key is seen, the scene's current state is recorded under
    it. Later pastes with the same key restore that state before running, and
    discard every checkpoint recorded after it.
    """

    def __init__(self):
        # Insertion order matters: handle_checkpoint_key relies on dict order
        # to find the checkpoints created after a given key.
        # NOTE(review): values are whatever `scene.get_state()` returns; confirm
        # the annotation below matches that type.
        self.checkpoint_states: dict[str, list[tuple[Mobject, Mobject]]] = dict()

    def checkpoint_paste(self, shell, scene):
        """
        Used during interactive development to run (or re-run)
        a block of scene code taken from the clipboard.

        If the copied selection starts with a comment, this will
        revert to the state of the scene the first time this function
        was called on a block of code starting with that comment.
        """
        code_string = pyperclip.paste()
        checkpoint_key = self.get_leading_comment(code_string)
        self.handle_checkpoint_key(scene, checkpoint_key)
        shell.run_cell(code_string)

    @staticmethod
    def get_leading_comment(code_string: str) -> str:
        """
        Return the first line of `code_string` if it is a comment, else "".

        Surrounding whitespace is stripped so that the same comment always
        yields the same key, regardless of indentation, trailing spaces, or the
        carriage return left behind when the text uses CRLF line endings.
        """
        leading_line = code_string.partition("\n")[0].strip()
        if leading_line.startswith("#"):
            return leading_line
        return ""

    def handle_checkpoint_key(self, scene, key: str):
        """
        Save or restore the checkpoint identified by `key`.

        An empty key does nothing. An unseen key records ``scene.get_state()``.
        A known key restores its state via ``scene.restore_state`` and drops
        every checkpoint that was recorded after it.
        """
        if not key:
            return
        elif key in self.checkpoint_states:
            # Revert to checkpoint
            scene.restore_state(self.checkpoint_states[key])

            # Clear out any saved states that show up later
            all_keys = list(self.checkpoint_states)
            index = all_keys.index(key)
            for later_key in all_keys[index + 1:]:
                self.checkpoint_states.pop(later_key)
        else:
            self.checkpoint_states[key] = scene.get_state()

    def clear_checkpoints(self):
        """Forget all recorded checkpoints."""
        self.checkpoint_states = dict()