Skip to content

Commit

Permalink
Fix matplotlib errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Nov 13, 2023
1 parent ae22805 commit b5f74ad
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 12 deletions.
3 changes: 2 additions & 1 deletion pybamm/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
import pybamm
from .quick_plot import ax_min, ax_max
from pybamm.util import have_optional_dependency


def plot(x, y, ax=None, testing=False, **kwargs):
Expand All @@ -25,7 +26,7 @@ def plot(x, y, ax=None, testing=False, **kwargs):
Keyword arguments, passed to plt.plot
"""
import matplotlib.pyplot as plt
plt = have_optional_dependency("matplotlib.pyplot")

if not isinstance(x, pybamm.Array):
raise TypeError("x must be 'pybamm.Array'")
Expand Down
3 changes: 2 additions & 1 deletion pybamm/plotting/plot2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
import pybamm
from .quick_plot import ax_min, ax_max
from pybamm.util import have_optional_dependency


def plot2D(x, y, z, ax=None, testing=False, **kwargs):
Expand All @@ -25,7 +26,7 @@ def plot2D(x, y, z, ax=None, testing=False, **kwargs):
Whether to actually make the plot (turned off for unit tests)
"""
import matplotlib.pyplot as plt
plt = have_optional_dependency("matplotlib.pyplot")

if not isinstance(x, pybamm.Array):
raise TypeError("x must be 'pybamm.Array'")
Expand Down
3 changes: 2 additions & 1 deletion pybamm/plotting/plot_summary_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
import numpy as np
import pybamm
from pybamm.util import have_optional_dependency


def plot_summary_variables(
Expand All @@ -25,7 +26,7 @@ def plot_summary_variables(
Keyword arguments, passed to plt.subplots.
"""
import matplotlib.pyplot as plt
plt = have_optional_dependency("matplotlib.pyplot")

if isinstance(solutions, pybamm.Solution):
solutions = [solutions]
Expand Down
4 changes: 3 additions & 1 deletion pybamm/plotting/plot_voltage_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
import numpy as np

from pybamm.util import have_optional_dependency


def plot_voltage_components(
solution,
Expand Down Expand Up @@ -32,7 +34,7 @@ def plot_voltage_components(
Keyword arguments, passed to ax.fill_between
"""
import matplotlib.pyplot as plt
plt = have_optional_dependency("matplotlib.pyplot")

# Set a default value for alpha, the opacity
kwargs_fill = {"alpha": 0.6, **kwargs_fill}
Expand Down
18 changes: 10 additions & 8 deletions pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pybamm
from collections import defaultdict
from pybamm.util import have_optional_dependency


class LoopList(list):
Expand Down Expand Up @@ -46,7 +47,7 @@ def split_long_string(title, max_words=None):

def close_plots():
"""Close all open figures"""
import matplotlib.pyplot as plt
plt = have_optional_dependency("matplotlib", "pyplot")

plt.close("all")

Expand Down Expand Up @@ -469,9 +470,10 @@ def plot(self, t, dynamic=False):
Dimensional time (in 'time_units') at which to plot.
"""

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import cm, colors
plt = have_optional_dependency("matplotlib.pyplot")
gridspec = have_optional_dependency("matplotlib.gridspec")
cm = have_optional_dependency("matplotlib", "cm")
colors = have_optional_dependency("matplotlib", "colors")

t_in_seconds = t * self.time_scaling_factor
self.fig = plt.figure(figsize=self.figsize)
Expand Down Expand Up @@ -668,8 +670,8 @@ def dynamic_plot(self, testing=False, step=None):
continuous_update=False,
)
else:
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
plt = have_optional_dependency("matplotlib.pyplot")
Slider = have_optional_dependency("matplotlib.widgets", "Slider")

# create an initial plot at time self.min_t
self.plot(self.min_t, dynamic=True)
Expand Down Expand Up @@ -773,8 +775,8 @@ def create_gif(self, number_of_images=80, duration=0.1, output_filename="plot.gi
Name of the generated GIF file.
"""
import imageio.v2 as imageio
import matplotlib.pyplot as plt
imageio = have_optional_dependency("imageio.v2")
plt = have_optional_dependency("matplotlib.pyplot")

# time stamps at which the images/plots will be created
time_array = np.linspace(self.min_t, self.max_t, num=number_of_images)
Expand Down

0 comments on commit b5f74ad

Please sign in to comment.