# Source code for qutip.visualization

"""
Functions for visualizing results of quantum dynamics simulations,
visualizations of quantum states and processes.
"""

# Public API of this module; every exported name is listed exactly once.
__all__ = ['hinton', 'sphereplot', 'energy_level_diagram',
           'plot_energy_levels', 'fock_distribution',
           'plot_fock_distribution', 'wigner_fock_distribution',
           'plot_wigner_fock_distribution', 'plot_wigner',
           'plot_expectation_values', 'plot_spin_distribution_2d',
           'plot_spin_distribution_3d', 'plot_qubism', 'plot_schmidt',
           'complex_array_to_rgb', 'matrix_histogram',
           'matrix_histogram_complex', 'plot_wigner_sphere']

import warnings
import itertools as it
import numpy as np
from numpy import pi, array, sin, cos, angle, log2

from packaging.version import parse as parse_version

from qutip.qobj import Qobj, isket
from qutip.states import ket2dm
from qutip.wigner import wigner
from qutip.tensor import tensor
from qutip.matplotlib_utilities import complex_phase_cmap
from qutip.superoperator import vector_to_operator
from qutip.superop_reps import _super_to_superpauli, _isqubitdims

from qutip import settings

# Matplotlib is treated as optional: when it cannot be imported the module
# still loads, and the names below (plt, mpl, cm, _axes3D) are simply left
# undefined.
try:
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    from matplotlib import cm
    from mpl_toolkits.mplot3d import Axes3D

    # Define a custom _axes3D function based on the matplotlib version.
    # The auto_add_to_figure keyword is new for matplotlib>=3.4.
    if parse_version(mpl.__version__) >= parse_version('3.4'):
        def _axes3D(fig, *args, **kwargs):
            """Create an Axes3D on ``fig`` and add it to the figure."""
            ax = Axes3D(fig, *args, auto_add_to_figure=False, **kwargs)
            return fig.add_axes(ax)
    else:
        def _axes3D(*args, **kwargs):
            """Create an Axes3D (no auto_add_to_figure keyword before 3.4)."""
            return Axes3D(*args, **kwargs)
except Exception:
    # Deliberate best effort: any failure while importing or probing
    # matplotlib is ignored.  ``Exception`` (rather than a bare ``except``)
    # lets KeyboardInterrupt and SystemExit propagate.
    pass


def plot_wigner_sphere(fig, ax, wigner, reflections):
    """Plots a coloured Bloch sphere.

    Parameters
    ----------
    fig : :obj:`matplotlib.figure.Figure`
        An instance of :obj:`~matplotlib.figure.Figure`. If ``None``, the
        figure that owns ``ax`` is used for the colourbar.
    ax : :obj:`matplotlib.axes.Axes`
        An axes instance in the given figure.
    wigner : list of float
        The wigner transformation at `steps` different theta and phi.
    reflections : bool
        If the reflections of the sphere should be plotted as well.

    Notes
    -----
    Special thanks to Russell P Rundle for writing this function.

    The function finishes by calling ``plt.show()``.
    """
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    steps = len(wigner)

    theta = np.linspace(0, np.pi, steps)
    phi = np.linspace(0, 2 * np.pi, steps)
    x = np.outer(np.sin(theta), np.cos(phi))
    y = np.outer(np.sin(theta), np.sin(phi))
    z = np.outer(np.cos(theta), np.ones(steps))
    wigner = np.real(wigner)
    wigner_max = np.real(np.amax(np.abs(wigner)))
    if wigner_max == 0:
        # An identically-zero input would otherwise give 0/0 below.
        wigner_max = 1.0

    # Map [-wigner_max, wigner_max] onto the [0, 1] range of the colormap.
    colours = cm.seismic_r((wigner + wigner_max) / (2 * wigner_max))

    # Plot coloured Bloch sphere:
    ax.plot_surface(x, y, z, facecolors=colours, vmin=-wigner_max,
                    vmax=wigner_max, rcount=steps, ccount=steps, linewidth=0,
                    zorder=0.5, antialiased=None)

    if reflections:
        # Flat planes at distance 1.5 from the origin, coloured like the
        # sphere itself.
        plane = 1.5 * np.ones((steps, steps))
        reflection_kw = dict(facecolors=colours, vmin=-wigner_max,
                             vmax=wigner_max, rcount=steps/2, ccount=steps/2,
                             linewidth=0, zorder=0.5, antialiased=False)

        # Plot bottom reflection:
        ax.plot_surface(x, y, -plane, **reflection_kw)
        # Plot side reflection:
        ax.plot_surface(-plane, y, z, **reflection_kw)
        # Plot back reflection:
        ax.plot_surface(x, plane, z, **reflection_kw)

    # Create colourbar.  The mappable is not attached to any axes, so the
    # axes to take space from has to be named explicitly.
    m = cm.ScalarMappable(cmap=cm.seismic_r)
    m.set_array([-wigner_max, wigner_max])
    figure = fig if fig is not None else ax.get_figure()
    figure.colorbar(m, ax=ax, shrink=0.5, aspect=10)

    plt.show()
# Adopted from the SciPy Cookbook. def _blob(x, y, w, w_max, area, color_fn, ax=None): """ Draws a square-shaped blob with the given area (< 1) at the given coordinates. """ hs = np.sqrt(area) / 2 xcorners = array([x - hs, x + hs, x + hs, x - hs]) ycorners = array([y - hs, y - hs, y + hs, y + hs]) if ax is not None: handle = ax else: handle = plt handle.fill(xcorners, ycorners, color=color_fn(w)) def _cb_labels(left_dims): """Creates plot labels for matrix elements in the computational basis. Parameters ---------- left_dims : flat list of ints Dimensions of the left index of a density operator. E. g. [2, 3] for a qubit tensored with a qutrit. Returns ------- left_labels, right_labels : lists of strings Labels for the left and right indices of a density operator (kets and bras, respectively). """ # FIXME: assumes dims, such that we only need left_dims == dims[0]. basis_labels = list(map(",".join, it.product(*[ map(str, range(dim)) for dim in left_dims ]))) return [ map(fmt.format, basis_labels) for fmt in ( r"$\langle{}|$", r"$|{}\rangle$", ) ] # Adopted from the SciPy Cookbook.
def hinton(rho, xlabels=None, ylabels=None, title=None, ax=None, cmap=None,
           label_top=True, color_style="scaled"):
    """Draws a Hinton diagram for visualizing a density matrix or
    superoperator.

    Parameters
    ----------
    rho : qobj
        Input density matrix or superoperator. An object that is not a
        :class:`Qobj` is used as the matrix to plot without conversion; it
        must provide ``.shape`` and two-index element access.

    xlabels : list of strings or False
        list of x labels

    ylabels : list of strings or False
        list of y labels

    title : string
        title of the plot (optional)

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    cmap : a matplotlib colormap instance
        Color map to use when plotting.

    label_top : bool
        If True, x-axis labels will be placed on top, otherwise
        they will appear below the plot.

    color_style : string
        Determines how colors are assigned to each square:

        -  If set to ``"scaled"`` (default), each color is chosen by
           passing the absolute value of the corresponding matrix
           element into `cmap` with the sign of the real part.
        -  If set to ``"threshold"``, each square is plotted as
           the maximum of `cmap` for the positive real part and as
           the minimum for the negative part of the matrix element;
           note that this generalizes `"threshold"` to complex numbers.
        -  If set to ``"phase"``, each color is chosen according to
           the angle of the corresponding matrix element.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure. When ``ax`` is supplied, ``fig`` is the figure that
        owns it.

    Raises
    ------
    ValueError
        If ``rho`` is a Qobj that is neither an operator, operator-ket,
        operator-bra nor superoperator; if it is a superoperator on
        non-qubit dimensions; if ``color_style`` is unknown; or if the
        number of labels does not match the matrix size.

    Examples
    --------
    >>> import qutip
    >>>
    >>> dm = qutip.rand_dm(4)
    >>> fig, ax = qutip.hinton(dm)
    >>> fig.show()
    >>>
    >>> qutip.settings.colorblind_safe = True
    >>> fig, ax = qutip.hinton(dm, color_style="threshold")
    >>> fig.show()
    >>> qutip.settings.colorblind_safe = False
    >>>
    >>> fig, ax = qutip.hinton(dm, color_style="phase")
    >>> fig.show()
    """

    # Apply default colormaps.
    # TODO: abstract this away into something that makes default
    #       colormaps.
    cmap = (
        (cm.Greys_r if settings.colorblind_safe else cm.RdBu)
        if cmap is None else cmap
    )

    # Extract plotting data W from the input.
    if isinstance(rho, Qobj):
        if rho.isoper:
            W = rho.full()

            # Create default labels if none are given.
            if xlabels is None or ylabels is None:
                labels = _cb_labels(rho.dims[0])
                xlabels = xlabels if xlabels is not None else list(labels[0])
                ylabels = ylabels if ylabels is not None else list(labels[1])

        elif rho.isoperket:
            W = vector_to_operator(rho).full()
        elif rho.isoperbra:
            W = vector_to_operator(rho.dag()).full()
        elif rho.issuper:
            if not _isqubitdims(rho.dims):
                raise ValueError("Hinton plots of superoperators are "
                                 "currently only supported for qubits.")
            # Convert to a superoperator in the Pauli basis,
            # so that all the elements are real.
            sqobj = _super_to_superpauli(rho)
            nq = int(log2(sqobj.shape[0]) / 2)
            W = sqobj.full().T
            # Create default labels, too.
            if (xlabels is None) or (ylabels is None):
                labels = list(map("".join, it.product("IXYZ", repeat=nq)))
                xlabels = xlabels if xlabels is not None else labels
                ylabels = ylabels if ylabels is not None else labels

        else:
            raise ValueError(
                "Input quantum object must be an operator or superoperator."
            )

    else:
        W = rho

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    else:
        # Return the figure that owns the caller's axes instead of None so
        # that the documented ``fig, ax`` result is always usable.
        fig = ax.get_figure()

    if not (xlabels or ylabels):
        ax.axis('off')
    if title:
        ax.set_title(title)

    ax.axis('equal')
    ax.set_frame_on(False)

    height, width = W.shape

    # 25% headroom so that the largest element does not fill its whole cell.
    w_max = 1.25 * max(abs(np.array(W)).flatten())
    if w_max <= 0.0:
        w_max = 1.0

    # Set color_fn here.
    if color_style == "scaled":
        def color_fn(w):
            w = np.abs(w) * np.sign(np.real(w))
            return cmap(int((w + w_max) * 256 / (2 * w_max)))
    elif color_style == "threshold":
        def color_fn(w):
            w = np.real(w)
            return cmap(255 if w > 0 else 0)
    elif color_style == "phase":
        def color_fn(w):
            return cmap(int(255 * (np.angle(w) / 2 / np.pi + 0.5)))
    else:
        raise ValueError(
            "Unknown color style {} for Hinton diagrams.".format(color_style)
        )

    # Background rectangle in the mid-range colour of the map.
    ax.fill(array([0, width, width, 0]), array([0, 0, height, height]),
            color=cmap(128))
    for x in range(width):
        for y in range(height):
            _x = x + 1
            _y = y + 1
            # Row 0 of W is drawn at the top of the diagram.
            _blob(
                _x - 0.5, height - _y + 0.5, W[y, x], w_max,
                min(1, abs(W[y, x]) / w_max), color_fn=color_fn, ax=ax)

    # color axis
    vmax = np.pi if color_style == "phase" else abs(W).max()
    norm = mpl.colors.Normalize(-vmax, vmax)
    cax, kw = mpl.colorbar.make_axes(ax, shrink=0.75, pad=.1)
    mpl.colorbar.ColorbarBase(cax, norm=norm, cmap=cmap)

    xtics = 0.5 + np.arange(width)
    # x axis
    ax.xaxis.set_major_locator(plt.FixedLocator(xtics))
    if xlabels:
        nxlabels = len(xlabels)
        if nxlabels != len(xtics):
            raise ValueError(f"got {nxlabels} xlabels but needed {len(xtics)}")
        ax.set_xticklabels(xlabels)
        if label_top:
            ax.xaxis.tick_top()
    ax.tick_params(axis='x', labelsize=14)

    # y axis
    ytics = 0.5 + np.arange(height)
    ax.yaxis.set_major_locator(plt.FixedLocator(ytics))
    if ylabels:
        nylabels = len(ylabels)
        if nylabels != len(ytics):
            raise ValueError(f"got {nylabels} ylabels but needed {len(ytics)}")
        # Reversed because row 0 is drawn at the top.
        ax.set_yticklabels(list(reversed(ylabels)))
    ax.tick_params(axis='y', labelsize=14)

    return fig, ax
def sphereplot(theta, phi, values, fig=None, ax=None, save=False):
    """Plots a matrix of values on a sphere

    Parameters
    ----------
    theta : array of float
        Angles with respect to z-axis

    phi : array of float
        Angles in x-y plane

    values : array
        Data set to be plotted. It is combined element-wise with
        ``np.meshgrid(theta, phi)``; the absolute value sets the radius and
        the complex phase sets the colour.

    fig : a matplotlib Figure instance
        The Figure canvas in which the plot will be drawn. If omitted, the
        figure owning ``ax`` is used, or a new one is created when ``ax``
        is omitted too.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn. If omitted, a new
        3D axes is created on ``fig``.

    save : bool {False , True}
        Whether to save the figure or not. The figure is written to
        ``"sphereplot.png"``.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.
    """
    # Only create what the caller did not supply.
    if ax is None:
        if fig is None:
            fig = plt.figure()
        ax = _axes3D(fig)
    elif fig is None:
        fig = ax.get_figure()

    thetam, phim = np.meshgrid(theta, phi)
    xx = sin(thetam) * cos(phim)
    yy = sin(thetam) * sin(phim)
    zz = cos(thetam)
    r = array(abs(values))
    ph = angle(values)
    # normalize color range based on phase angles in list ph
    nrm = mpl.colors.Normalize(ph.min(), ph.max())

    # plot with facecolors set to cm.jet colormap normalized to nrm
    ax.plot_surface(r * xx, r * yy, r * zz, rstride=1, cstride=1,
                    facecolors=cm.jet(nrm(ph)), linewidth=0)
    # create new axes on plot for colorbar and shrink it a bit.
    # pad shifts location of bar with respect to the main plot
    cax, kw = mpl.colorbar.make_axes(ax, shrink=.66, pad=.02)

    # create new colorbar in axes cax with cm jet and normalized to nrm like
    # our facecolors
    cb1 = mpl.colorbar.ColorbarBase(cax, cmap=cm.jet, norm=nrm)
    # add our colorbar label
    cb1.set_label('Angle')

    if save:
        # Save this figure explicitly rather than pyplot's current figure.
        fig.savefig("sphereplot.png")

    return fig, ax
def _remove_margins(axis):
    """
    removes margins about z = 0 and improves the style
    by monkey patching

    Parameters
    ----------
    axis : a matplotlib 3D axis object
        Its ``_get_coord_info`` method is replaced, on this instance only,
        by a wrapper that moves ``mins`` up and ``maxs`` down by a quarter
        of ``deltas``.
    """
    # NOTE(review): relies on the private ``_get_coord_info(renderer)``
    # returning a 6-tuple -- confirm against the matplotlib versions in use.
    def _get_coord_info_new(renderer):
        mins, maxs, centers, deltas, tc, highs = \
            _get_coord_info_old(renderer)
        mins += deltas / 4
        maxs -= deltas / 4
        return mins, maxs, centers, deltas, tc, highs

    _get_coord_info_old = axis._get_coord_info
    axis._get_coord_info = _get_coord_info_new


def _truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    truncates portion of a colormap and returns the new one

    Parameters
    ----------
    cmap : string or matplotlib colormap
        A name is resolved with ``plt.get_cmap``.
    minval, maxval : float
        Bounds of the part of the colormap that is kept.
    n : int
        Number of samples taken between ``minval`` and ``maxval``.

    Returns
    -------
    new_cmap : matplotlib.colors.LinearSegmentedColormap
    """
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    new_cmap = mpl.colors.LinearSegmentedColormap.from_list(
        'trunc({n},{a:.2f},{b:.2f})'.format(
            n=cmap.name, a=minval, b=maxval),
        cmap(np.linspace(minval, maxval, n)))
    return new_cmap


def _stick_to_planes(stick, azim, ax, M, spacing):
    """adjusts xlim and ylim in way that bars will
    Stick to xz and yz planes

    Nothing happens unless ``stick is True``. Which limits are changed
    depends on the quadrant of ``azim`` (taken modulo 360).
    """
    if stick is True:
        azim = azim % 360
        if 0 <= azim <= 90:
            ax.set_ylim(1 - .5,)
            ax.set_xlim(1 - .5,)
        elif 90 < azim <= 180:
            ax.set_ylim(1 - .5,)
            ax.set_xlim(0, M.shape[0] + (.5 - spacing))
        elif 180 < azim <= 270:
            ax.set_ylim(0, M.shape[1] + (.5 - spacing))
            ax.set_xlim(0, M.shape[0] + (.5 - spacing))
        elif 270 < azim < 360:
            ax.set_ylim(0, M.shape[1] + (.5 - spacing))
            ax.set_xlim(1 - .5,)


def _update_yaxis(spacing, M, ax, ylabels):
    """
    updates the y-axis

    Places one tick per column of ``M`` (``M.shape[1]``) and labels the
    ticks with ``ylabels`` or, if none are given, with ``0, 1, 2, ...``.

    Raises
    ------
    ValueError
        If the number of ``ylabels`` differs from the number of ticks.
    """
    ytics = [y + (1 - (spacing / 2)) for y in range(M.shape[1])]
    # The w_yaxis alias is used for matplotlib older than 3.8.
    if parse_version(mpl.__version__) >= parse_version("3.8"):
        ax.axes.yaxis.set_major_locator(plt.FixedLocator(ytics))
    else:
        ax.axes.w_yaxis.set_major_locator(plt.FixedLocator(ytics))
    if ylabels:
        nylabels = len(ylabels)
        if nylabels != len(ytics):
            raise ValueError(f"got {nylabels} ylabels but needed {len(ytics)}")
        ax.set_yticklabels(ylabels)
    else:
        ax.set_yticklabels([str(i) for i in range(M.shape[1])])

    ax.tick_params(axis='y', labelsize=14)
    ax.set_yticks(ytics)


def _update_xaxis(spacing, M, ax, xlabels):
    """
    updates the x-axis

    Places one tick per row of ``M`` (``M.shape[0]``) and labels the ticks
    with ``xlabels`` or, if none are given, with ``0, 1, 2, ...``.

    Raises
    ------
    ValueError
        If the number of ``xlabels`` differs from the number of ticks.
    """
    # The x direction spans M.shape[0]; using shape[1] here would give the
    # wrong tick count for non-square matrices.
    xtics = [x + (1 - (spacing / 2)) for x in range(M.shape[0])]
    # The w_xaxis alias is used for matplotlib older than 3.8.
    if parse_version(mpl.__version__) >= parse_version("3.8"):
        ax.axes.xaxis.set_major_locator(plt.FixedLocator(xtics))
    else:
        ax.axes.w_xaxis.set_major_locator(plt.FixedLocator(xtics))
    if xlabels:
        nxlabels = len(xlabels)
        if nxlabels != len(xtics):
            raise ValueError(f"got {nxlabels} xlabels but needed {len(xtics)}")
        ax.set_xticklabels(xlabels)
    else:
        ax.set_xticklabels([str(i) for i in range(M.shape[0])])

    ax.tick_params(axis='x', labelsize=14)
    ax.set_xticks(xtics)


def _update_zaxis(ax, z_min, z_max, zticks):
    """
    updates the z-axis

    Explicit ``zticks`` are applied only when given as a ``list``. The
    lower z limit is ``min(z_min, 0)``.
    """
    # The w_zaxis alias is used for matplotlib older than 3.8.
    if parse_version(mpl.__version__) >= parse_version("3.8"):
        ax.axes.zaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    else:
        ax.axes.w_zaxis.set_major_locator(plt.IndexLocator(1, 0.5))
    if isinstance(zticks, list):
        ax.set_zticks(zticks)
    ax.set_zlim3d([min(z_min, 0), z_max])
def matrix_histogram(M, xlabels=None, ylabels=None, title=None, limits=None,
                     colorbar=True, fig=None, ax=None, options=None):
    """
    Draw a histogram for the matrix M, with the given x and y labels and title.

    Parameters
    ----------
    M : Matrix of Qobj
        The matrix to visualize. Only the real part is plotted.

    xlabels : list of strings
        list of x labels

    ylabels : list of strings
        list of y labels

    title : string
        title of the plot (optional)

    limits : list/tuple/array with two float numbers
        The z-axis limits [min, max] (optional)

    colorbar : bool (default: True)
        show colorbar

    fig : a matplotlib Figure instance
        Returned unchanged when ``ax`` is given; a new figure is created
        when ``ax`` is omitted.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    options : dict
        A dictionary containing extra options for the plot.
        The names (keys) and values of the options are
        described below:

        'zticks' : list of numbers
            A list of z-axis tick locations.

        'cmap' : string (default: 'jet')
            The name of the color map to use.

        'cmap_min' : float (default: 0.0)
            The lower bound to truncate the color map at.
            A value in range 0 - 1. The default, 0, leaves the lower
            bound of the map unchanged.

        'cmap_max' : float (default: 1.0)
            The upper bound to truncate the color map at.
            A value in range 0 - 1. The default, 1, leaves the upper
            bound of the map unchanged.

        'bars_spacing' : float (default: 0.2)
            spacing between bars.

        'bars_alpha' : float (default: 1.)
            transparency of bars, should be in range 0 - 1

        'bars_lw' : float (default: 0.5)
            linewidth of bars' edges.

        'bars_edgecolor' : color (default: 'k')
            The colors of the bars' edges.
            Examples: 'k', (0.1, 0.2, 0.5) or '#0f0f0f80'.

        'shade' : bool (default: False)
            Whether to shade the dark sides of the bars (True) or not (False).
            The shading is relative to plot's source of light.

        'azim' : float (default: -35)
            The azimuthal viewing angle. Only applied when ``ax`` is not
            passed.

        'elev' : float (default: 35)
            The elevation viewing angle. Only applied when ``ax`` is not
            passed.

        'proj_type' : string (default: 'ortho' if ax is not passed)
            The type of projection ('ortho' or 'persp')

        'stick' : bool (default: False)
            Changes xlim and ylim in such a way that bars next to
            XZ and YZ planes will stick to those planes. The planes are
            chosen from the 'azim' option.

        'cbar_pad' : float (default: 0.04)
            The fraction of the original axes between the colorbar
            and the new image axes.
            (i.e. the padding between the 3D figure and the colorbar).

        'cbar_to_z' : bool (default: False)
            Whether to set the color of maximum and minimum z-values to the
            maximum and minimum colors in the colorbar (True) or not (False).

        'figsize' : tuple of two numbers
            The size of the figure.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------
    ValueError
        Input argument is not valid.

    """

    # default options
    default_opts = {'figsize': None, 'cmap': 'jet', 'cmap_min': 0.,
                    'cmap_max': 1., 'zticks': None, 'bars_spacing': 0.2,
                    'bars_alpha': 1., 'bars_lw': 0.5, 'bars_edgecolor': 'k',
                    'shade': False, 'azim': -35, 'elev': 35,
                    'proj_type': 'ortho', 'stick': False,
                    'cbar_pad': 0.04, 'cbar_to_z': False}

    # update default_opts from input options
    if options is None:
        pass
    elif isinstance(options, dict):
        # check if keys in options dict are valid
        unknown = set(options) - set(default_opts)
        if unknown:
            raise ValueError("invalid key(s) found in options: "
                             f"{', '.join(sorted(map(str, unknown)))}")
        # updating default options
        default_opts.update(options)
    else:
        raise ValueError("options must be a dictionary")

    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()
    # Plain nested lists and np.matrix are handled like a 2D ndarray.
    M = np.asarray(M)

    n = np.size(M)
    xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    xpos = xpos.T.flatten() + 0.5
    ypos = ypos.T.flatten() + 0.5
    zpos = np.zeros(n)
    dx = dy = (1 - default_opts['bars_spacing']) * np.ones(n)
    dz = np.real(M.flatten())

    # Accept every two-element sequence type the docstring promises.
    if isinstance(limits, (list, tuple, np.ndarray)) and len(limits) == 2:
        z_min = limits[0]
        z_max = limits[1]
    else:
        z_min = min(dz)
        z_max = max(dz)
        if z_min == z_max:
            # Avoid a degenerate (zero-width) z range.
            z_min -= 0.1
            z_max += 0.1

    if default_opts['cbar_to_z']:
        norm = mpl.colors.Normalize(min(dz), max(dz))
    else:
        norm = mpl.colors.Normalize(z_min, z_max)
    cmap = _truncate_colormap(default_opts['cmap'],
                              default_opts['cmap_min'],
                              default_opts['cmap_max'])
    colors = cmap(norm(dz))

    if ax is None:
        fig = plt.figure(figsize=default_opts['figsize'])
        ax = _axes3D(fig,
                     azim=default_opts['azim'] % 360,
                     elev=default_opts['elev'] % 360)
        ax.set_proj_type(default_opts['proj_type'])

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors,
             edgecolors=default_opts['bars_edgecolor'],
             linewidths=default_opts['bars_lw'],
             alpha=default_opts['bars_alpha'],
             shade=default_opts['shade'])
    # remove vertical lines on xz and yz plane
    ax.yaxis._axinfo["grid"]['linewidth'] = 0
    ax.xaxis._axinfo["grid"]['linewidth'] = 0

    if title:
        ax.set_title(title)

    # x axis
    _update_xaxis(default_opts['bars_spacing'], M, ax, xlabels)

    # y axis
    _update_yaxis(default_opts['bars_spacing'], M, ax, ylabels)

    # z axis
    _update_zaxis(ax, z_min, z_max, default_opts['zticks'])

    # stick to xz and yz plane
    _stick_to_planes(default_opts['stick'],
                     default_opts['azim'], ax, M,
                     default_opts['bars_spacing'])

    # color axis
    if colorbar:
        cax, kw = mpl.colorbar.make_axes(ax, shrink=.75,
                                         pad=default_opts['cbar_pad'])
        mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)

    # removing margins
    _remove_margins(ax.xaxis)
    _remove_margins(ax.yaxis)
    _remove_margins(ax.zaxis)

    return fig, ax
def matrix_histogram_complex(M, xlabels=None, ylabels=None,
                             title=None, limits=None, phase_limits=None,
                             colorbar=True, fig=None, ax=None,
                             threshold=None):
    """
    Draw a histogram for the amplitudes of matrix M, using the argument
    of each element for coloring the bars, with the given x and y labels
    and title.

    Parameters
    ----------
    M : Matrix of Qobj
        The matrix to visualize

    xlabels : list of strings
        list of x labels

    ylabels : list of strings
        list of y labels

    title : string
        title of the plot (optional)

    limits : list/tuple/array with two float numbers
        The z-axis limits [min, max] (optional). Defaults to ``[0, 1]``.

    phase_limits : list/tuple/array with two float numbers
        The phase-axis (colorbar) limits [min, max] (optional). Defaults to
        ``[-pi, pi]``.

    colorbar : bool (default: True)
        show colorbar

    fig : a matplotlib Figure instance
        Returned unchanged when ``ax`` is given; a new figure is created
        when ``ax`` is omitted.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    threshold: float (None)
        Threshold for when bars of smaller height should be transparent. If
        not set, all bars are colored according to the color map.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------
    ValueError
        Input argument is not valid.

    """

    if isinstance(M, Qobj):
        # extract matrix data from Qobj
        M = M.full()
    # Plain nested lists and np.matrix are handled like a 2D ndarray.
    M = np.asarray(M)

    n = np.size(M)
    xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
    xpos = xpos.T.flatten() - 0.5
    ypos = ypos.T.flatten() - 0.5
    zpos = np.zeros(n)
    dx = dy = 0.8 * np.ones(n)
    Mvec = M.flatten()
    dz = abs(Mvec)

    # make small numbers real, to avoid random colors
    idx, = np.where(abs(Mvec) < 0.001)
    Mvec[idx] = abs(Mvec[idx])

    # ``is not None`` / ``len`` instead of truthiness: truth-testing a
    # NumPy array of limits raises ValueError.
    if phase_limits is not None and len(phase_limits) > 0:
        phase_min = phase_limits[0]
        phase_max = phase_limits[1]
    else:
        phase_min = -pi
        phase_max = pi

    norm = mpl.colors.Normalize(phase_min, phase_max)
    cmap = complex_phase_cmap()

    colors = cmap(norm(angle(Mvec)))
    if threshold is not None:
        # Alpha channel: opaque above the threshold, transparent otherwise.
        colors[:, 3] = 1 * (dz > threshold)

    if ax is None:
        fig = plt.figure()
        ax = _axes3D(fig, azim=-35, elev=35)

    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color=colors)

    if title:
        ax.set_title(title)

    # x axis
    xtics = -0.5 + np.arange(M.shape[0])
    if parse_version(mpl.__version__) >= parse_version("3.8"):
        ax.axes.xaxis.set_major_locator(plt.FixedLocator(xtics))
    else:
        ax.axes.w_xaxis.set_major_locator(plt.FixedLocator(xtics))
    if xlabels:
        nxlabels = len(xlabels)
        if nxlabels != len(xtics):
            raise ValueError(f"got {nxlabels} xlabels but needed {len(xtics)}")
        ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', labelsize=12)

    # y axis
    ytics = -0.5 + np.arange(M.shape[1])
    if parse_version(mpl.__version__) >= parse_version("3.8"):
        ax.axes.yaxis.set_major_locator(plt.FixedLocator(ytics))
    else:
        ax.axes.w_yaxis.set_major_locator(plt.FixedLocator(ytics))
    if ylabels:
        nylabels = len(ylabels)
        if nylabels != len(ytics):
            raise ValueError(f"got {nylabels} ylabels but needed {len(ytics)}")
        ax.set_yticklabels(ylabels)
    ax.tick_params(axis='y', labelsize=12)

    # z axis
    if isinstance(limits, (list, tuple, np.ndarray)) and len(limits) > 0:
        ax.set_zlim3d(limits)
    else:
        ax.set_zlim3d([0, 1])  # use min/max
    # ax.set_zlabel('abs')

    # color axis
    if colorbar:
        cax, kw = mpl.colorbar.make_axes(ax, shrink=.75, pad=.0)
        cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
        cb.set_ticks([-pi, -pi / 2, 0, pi / 2, pi])
        cb.set_ticklabels(
            (r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'))
        cb.set_label('arg')

    return fig, ax
def plot_energy_levels(H_list, N=0, labels=None, show_ylabels=False,
                       figsize=(8, 12), fig=None, ax=None):
    """
    Plot the energy level diagrams for a list of Hamiltonians. Include
    up to N energy levels. For each element in H_list, the energy
    level diagram for the cumulative Hamiltonian sum(H_list[0:n + 1]) is
    plotted, where n is the index of an element in H_list.

    Parameters
    ----------

    H_list : List of Qobj
        A list of Hamiltonians.

    N : int
        The number of energy levels to plot. ``0`` (the default) plots all
        levels of the first Hamiltonian.

    labels : List of string
        A list of labels for each Hamiltonian

    show_ylabels : Bool (default False)
        Show y labels to the left of energy levels of the initial
        Hamiltonian.

    figsize : tuple (int,int)
        The size of the figure (width, height). Only used when a new figure
        is created.

    fig : a matplotlib Figure instance
        The Figure canvas in which the plot will be drawn.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------

    ValueError
        Input argument is not valid.

    """

    if not isinstance(H_list, list):
        raise ValueError("H_list must be a list of Qobj instances")
    if not H_list:
        raise ValueError("H_list must contain at least one Hamiltonian")

    # Only create what the caller did not supply.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=figsize)
        else:
            ax = fig.add_subplot(1, 1, 1)
    elif fig is None:
        fig = ax.get_figure()

    H = H_list[0]
    N = H.shape[0] if N == 0 else min(H.shape[0], N)

    xticks = []
    yticks = []

    # x is the left edge of the column currently being drawn: every
    # Hamiltonian gets a column of width 2, connectors have width 1.
    x = 0
    evals0 = H.eigenenergies(eigvals=N)
    for e in evals0[:N]:
        ax.plot([x, x + 2], np.array([1, 1]) * e, 'b', linewidth=2)
        yticks.append(e)
    xticks.append(x + 1)
    x += 2

    for H1 in H_list[1:]:

        H = H + H1
        evals1 = H.eigenenergies()

        # Dotted connectors from the previous levels to the new ones.
        for e_idx, e in enumerate(evals1[:N]):
            ax.plot([x, x + 1], np.array([evals0[e_idx], e]), 'k:')
        x += 1

        for e in evals1[:N]:
            ax.plot([x, x + 2], np.array([1, 1]) * e, 'b', linewidth=2)
        xticks.append(x + 1)
        x += 2

        evals0 = evals1

    ax.set_frame_on(False)

    if show_ylabels:
        yticks = np.unique(np.around(yticks, 1))
        ax.set_yticks(yticks)
    else:
        ax.axes.get_yaxis().set_visible(False)

    if labels:
        ax.get_xaxis().tick_bottom()
        ax.set_xticks(xticks)
        ax.set_xticklabels(labels, fontsize=16)
    else:
        ax.axes.get_xaxis().set_visible(False)

    return fig, ax
def energy_level_diagram(H_list, N=0, labels=None, show_ylabels=False,
                         figsize=(8, 12), fig=None, ax=None):
    """
    Deprecated alias of :func:`plot_energy_levels`.

    Emits a warning and forwards every argument unchanged.
    """
    # stacklevel=2 attributes the warning to the caller of this alias.
    warnings.warn("Deprecated: Use plot_energy_levels", stacklevel=2)
    return plot_energy_levels(H_list, N=N, labels=labels,
                              show_ylabels=show_ylabels,
                              figsize=figsize, fig=fig, ax=ax)
def plot_fock_distribution(rho, offset=0, fig=None, ax=None,
                           figsize=(8, 6), title=None, unit_y_range=True):
    """
    Plot the Fock distribution for a density matrix (or ket) that describes
    an oscillator mode.

    Parameters
    ----------
    rho : :class:`qutip.Qobj`
        The density matrix (or ket) of the state to visualize.

    offset : int
        Number added to the x position of every bar, i.e. the label of the
        first bar.

    fig : a matplotlib Figure instance
        The Figure canvas in which the plot will be drawn.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, if no 'fig' and 'ax' arguments are passed).

    title : string
        An optional title for the figure.

    unit_y_range : bool
        If True (default), the y-axis range is fixed to [0, 1].

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.
    """

    # Only create what the caller did not supply.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=figsize)
        else:
            ax = fig.add_subplot(1, 1, 1)
    elif fig is None:
        fig = ax.get_figure()

    if isket(rho):
        rho = ket2dm(rho)

    N = rho.shape[0]

    # The diagonal of the density matrix holds the occupation probabilities.
    ax.bar(np.arange(offset, offset + N), np.real(rho.diag()),
           color="green", alpha=0.6, width=0.8)
    if unit_y_range:
        ax.set_ylim(0, 1)

    ax.set_xlim(-.5 + offset, N + offset)
    ax.set_xlabel('Fock number', fontsize=12)
    ax.set_ylabel('Occupation probability', fontsize=12)

    if title:
        ax.set_title(title)

    return fig, ax
def fock_distribution(rho, offset=0, fig=None, ax=None,
                      figsize=(8, 6), title=None, unit_y_range=True):
    """
    Deprecated alias of :func:`plot_fock_distribution`.

    Emits a warning and forwards every argument unchanged.
    """
    # stacklevel=2 attributes the warning to the caller of this alias.
    warnings.warn("Deprecated: Use plot_fock_distribution", stacklevel=2)
    return plot_fock_distribution(rho, offset=offset, fig=fig, ax=ax,
                                  figsize=figsize, title=title,
                                  unit_y_range=unit_y_range)
def plot_wigner(rho, fig=None, ax=None, figsize=(6, 6),
                cmap=None, alpha_max=7.5, colorbar=False,
                method='clenshaw', projection='2d'):
    """
    Plot the Wigner function for a density matrix (or ket) that describes
    an oscillator mode.

    Parameters
    ----------
    rho : :class:`qutip.Qobj`
        The density matrix (or ket) of the state to visualize.

    fig : a matplotlib Figure instance
        The Figure canvas in which the plot will be drawn.

    ax : a matplotlib axes instance
        The axes context in which the plot will be drawn.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, if no 'fig' and 'ax' arguments are passed).

    cmap : a matplotlib cmap instance
        The colormap. Defaults to ``'RdBu'``.

    alpha_max : float
        The span of the x and y coordinates (both [-alpha_max, alpha_max]).

    colorbar : bool
        Whether (True) or not (False) a colorbar should be attached to the
        Wigner function graph.

    method : string {'clenshaw', 'iterative', 'laguerre', 'fft'}
        The method used for calculating the wigner function. See the
        documentation for qutip.wigner for details.

    projection: string {'2d', '3d'}
        Specify whether the Wigner function is to be plotted as a
        contour graph ('2d') or surface plot ('3d').

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------
    ValueError
        If ``projection`` is neither ``'2d'`` nor ``'3d'``.
    """

    # Only create what the caller did not supply.
    if ax is None:
        if projection not in ('2d', '3d'):
            raise ValueError('Unexpected value of projection keyword argument')
        if fig is None:
            fig = plt.figure(figsize=figsize)
        if projection == '2d':
            ax = fig.add_subplot(1, 1, 1)
        else:
            ax = fig.add_subplot(1, 1, 1, projection='3d')
    elif fig is None:
        fig = ax.get_figure()

    if isket(rho):
        rho = ket2dm(rho)

    xvec = np.linspace(-alpha_max, alpha_max, 200)
    W0 = wigner(rho, xvec, xvec, method=method)

    # wigner() may hand back (W, yvec) instead of a bare array; normalise.
    W, yvec = W0 if isinstance(W0, tuple) else (W0, xvec)

    # Symmetric colour range so that zero sits at the centre of the map.
    wlim = abs(W).max()

    if cmap is None:
        cmap = plt.get_cmap('RdBu')

    if projection == '2d':
        cf = ax.contourf(xvec, yvec, W, 100,
                         norm=mpl.colors.Normalize(-wlim, wlim), cmap=cmap)
    elif projection == '3d':
        # Use the unpacked W and its own y grid; W0 can be a tuple.
        X, Y = np.meshgrid(xvec, yvec)
        cf = ax.plot_surface(X, Y, W, rstride=5, cstride=5, linewidth=0.5,
                             norm=mpl.colors.Normalize(-wlim, wlim), cmap=cmap)
    else:
        raise ValueError('Unexpected value of projection keyword argument.')

    if xvec is not yvec:
        # A different y grid was returned: clip the view to the x range.
        ax.set_ylim(xvec.min(), xvec.max())

    ax.set_xlabel(r'$\rm{Re}(\alpha)$', fontsize=12)
    ax.set_ylabel(r'$\rm{Im}(\alpha)$', fontsize=12)

    if colorbar:
        fig.colorbar(cf, ax=ax)

    ax.set_title("Wigner function", fontsize=12)

    return fig, ax
def plot_wigner_fock_distribution(rho, fig=None, axes=None, figsize=(8, 4),
                                  cmap=None, alpha_max=7.5, colorbar=False,
                                  method='iterative', projection='2d'):
    """
    Plot the Fock distribution and the Wigner function for a density matrix
    (or ket) that describes an oscillator mode.

    Parameters
    ----------
    rho : :class:`qutip.Qobj`
        The density matrix (or ket) of the state to visualize.

    fig : a matplotlib Figure instance
        The Figure canvas in which the plot will be drawn.

    axes : a list of two matplotlib axes instances
        The axes context in which the plot will be drawn. ``axes[0]``
        receives the Fock distribution and ``axes[1]`` the Wigner function.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, if no 'fig' and 'ax' arguments are passed).

    cmap : a matplotlib cmap instance
        The colormap.

    alpha_max : float
        The span of the x and y coordinates (both [-alpha_max, alpha_max]).

    colorbar : bool
        Whether (True) or not (False) a colorbar should be attached to the
        Wigner function graph.

    method : string {'iterative', 'laguerre', 'fft'}
        The method used for calculating the wigner function. See the
        documentation for qutip.wigner for details.

    projection: string {'2d', '3d'}
        Specify whether the Wigner function is to be plotted as a
        contour graph ('2d') or surface plot ('3d').

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------
    ValueError
        If ``projection`` is neither ``'2d'`` nor ``'3d'``.
    """

    # ``is None`` rather than truthiness: ``axes`` is typically the ndarray
    # returned by plt.subplots, whose truth value is ambiguous.
    if axes is None:
        if projection not in ('2d', '3d'):
            raise ValueError('Unexpected value of projection keyword argument')
        if fig is None:
            fig = plt.figure(figsize=figsize)
        if projection == '2d':
            axes = fig.subplots(1, 2)
        else:
            axes = [fig.add_subplot(1, 2, 1),
                    fig.add_subplot(1, 2, 2, projection='3d')]
    elif fig is None:
        fig = axes[0].get_figure()

    if isket(rho):
        rho = ket2dm(rho)

    plot_fock_distribution(rho, fig=fig, ax=axes[0])
    plot_wigner(rho, fig=fig, ax=axes[1], figsize=figsize, cmap=cmap,
                alpha_max=alpha_max, colorbar=colorbar, method=method,
                projection=projection)

    return fig, axes
def wigner_fock_distribution(rho, fig=None, axes=None, figsize=(8, 4),
                             cmap=None, alpha_max=7.5, colorbar=False,
                             method='iterative'):
    """
    Deprecated alias of :func:`plot_wigner_fock_distribution`.

    Emits a warning and forwards every argument unchanged. The `projection`
    option of the new function is not exposed here, so its default applies.
    """
    warnings.warn("Deprecated: Use plot_wigner_fock_distribution")
    forwarded = dict(fig=fig, axes=axes, figsize=figsize, cmap=cmap,
                     alpha_max=alpha_max, colorbar=colorbar, method=method)
    return plot_wigner_fock_distribution(rho, **forwarded)
def plot_expectation_values(results, ylabels=None, title=None,
                            show_legend=False, fig=None, axes=None,
                            figsize=(8, 4)):
    """
    Visualize the results (expectation values) for an evolution solver.
    `results` is assumed to be an instance of Result, or a list of Result
    instances.

    Parameters
    ----------
    results : (list of) :class:`qutip.solver.Result`
        List of results objects returned by any of the QuTiP evolution
        solvers.

    ylabels : list of strings
        The y-axis labels, one per expectation-value row (entry ``n`` labels
        the ``n``-th row of axes).

    title : string
        The title of the figure.

    show_legend : bool
        Whether or not to show the legend.

    fig : a matplotlib Figure instance
        The Figure canvas in which the plot will be drawn.

    axes : a matplotlib axes instance, or a sequence/array of them
        The axes context in which the plot will be drawn, one axes per
        expectation-value row. A single axes, a flat sequence, or the
        column-shaped array returned by this function are all accepted.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, if no 'fig' and 'axes' arguments are passed).

    Returns
    -------
    fig, axes : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure. User-supplied `axes` are returned unchanged.
    """
    if not isinstance(results, list):
        results = [results]

    n_e_ops = max(len(result.expect) for result in results)

    # Explicit None checks: `axes` may be a numpy array of Axes, and the
    # truth value of a multi-element array raises ValueError.
    if fig is None or axes is None:
        if not figsize:
            figsize = (12, 3 * n_e_ops)
        fig, axes = plt.subplots(n_e_ops, 1, sharex=True,
                                 figsize=figsize, squeeze=False)

    # Work on a column-shaped view so that a single Axes or a flat list of
    # Axes can be indexed the same way as the grid created above.
    grid = np.asarray(axes, dtype=object).reshape(-1, 1)

    for result in results:
        for e_idx, e in enumerate(result.expect):
            grid[e_idx, 0].plot(result.times, e,
                                label="%s [%d]" % (result.solver, e_idx))

    if title:
        fig.suptitle(title)

    grid[n_e_ops - 1, 0].set_xlabel("time", fontsize=12)
    for n in range(n_e_ops):
        if show_legend:
            grid[n, 0].legend()
        if ylabels:
            grid[n, 0].set_ylabel(ylabels[n], fontsize=12)

    return fig, axes
def plot_spin_distribution_2d(P, THETA, PHI,
                              fig=None, ax=None, figsize=(8, 8)):
    """
    Draw a spin distribution (supplied as meshgrid data) as a flat map in
    which the surface of the unit sphere is projected onto the unit disk.

    Parameters
    ----------
    P : matrix
        Distribution values as a meshgrid matrix.

    THETA : matrix
        Meshgrid matrix for the theta coordinate.

    PHI : matrix
        Meshgrid matrix for the phi coordinate.

    fig : a matplotlib figure instance
        The figure canvas on which the plot will be drawn.

    ax : a matplotlib axis instance
        The axis context in which the plot will be drawn.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, unless both 'fig' and 'ax' arguments are passed).

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.
    """
    # A fresh canvas is made unless both a figure and an axis were handed in.
    if not (fig and ax):
        fig, ax = plt.subplots(1, 1, figsize=figsize or (8, 8))

    # Shift theta so that the equator sits at zero, then flatten the sphere.
    latitude = THETA - pi / 2
    Y = latitude / (pi / 2)
    X = (pi - PHI) / pi * np.sqrt(cos(latitude))

    # NOTE(review): the -1e12 threshold is almost never crossed; it may have
    # been meant as -1e-12 (i.e. "P has negative values") - confirm.
    cmap = cm.RdBu if P.min() < -1e12 else cm.RdYlBu

    ax.pcolor(X, Y, P.real, cmap=cmap)

    ax.set_xlabel(r'$\varphi$', fontsize=18)
    ax.set_ylabel(r'$\theta$', fontsize=18)

    phi_ticks = [r'$0$', r'$\pi$', r'$2\pi$']
    theta_ticks = [r'$\pi$', r'$\pi/2$', r'$0$']
    ax.set_xticks([-1, 0, 1])
    ax.set_xticklabels(phi_ticks, fontsize=18)
    ax.set_yticks([-1, 0, 1])
    ax.set_yticklabels(theta_ticks, fontsize=18)

    return fig, ax
def plot_spin_distribution_3d(P, THETA, PHI,
                              fig=None, ax=None, figsize=(8, 6)):
    """
    Paint a matrix of values onto the surface of the unit sphere.

    Parameters
    ----------
    P : matrix
        Distribution values as a meshgrid matrix.

    THETA : matrix
        Meshgrid matrix for the theta coordinate.

    PHI : matrix
        Meshgrid matrix for the phi coordinate.

    fig : a matplotlib figure instance
        The figure canvas on which the plot will be drawn.

    ax : a matplotlib axis instance
        The axis context in which the plot will be drawn.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, unless both 'fig' and 'ax' arguments are passed).

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.
    """
    if fig is None or ax is None:
        fig = plt.figure(figsize=figsize)
        ax = _axes3D(fig, azim=-35, elev=35)

    # Spherical -> Cartesian coordinates on the unit sphere.
    sin_theta = sin(THETA)
    xx = sin_theta * cos(PHI)
    yy = sin_theta * sin(PHI)
    zz = cos(THETA)

    # NOTE(review): the -1e12 threshold is almost never crossed; it may have
    # been meant as -1e-12 (i.e. "P has negative values") - confirm.
    upper = P.max()
    if P.min() < -1e12:
        # Symmetric colour scale around zero.
        cmap, lower = cm.RdBu, -upper
    else:
        cmap, lower = cm.RdYlBu, P.min()
    norm = mpl.colors.Normalize(lower, upper)

    face_colours = cmap(norm(P))
    ax.plot_surface(xx, yy, zz, rstride=1, cstride=1,
                    facecolors=face_colours, linewidth=0)

    # The surface carries explicit face colours, so the colorbar is built
    # by hand from the same colormap and normalisation.
    cax, _ = mpl.colorbar.make_axes(ax, shrink=.66, pad=.02)
    colour_bar = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    colour_bar.set_label('magnitude')

    return fig, ax
# # Qubism and other qubistic visualizations # def complex_array_to_rgb(X, theme='light', rmax=None): """ Makes an array of complex number and converts it to an array of [r, g, b], where phase gives hue and saturation/value are given by the absolute value. Especially for use with imshow for complex plots. For more info on coloring, see: Emilia Petrisor, Visualizing complex-valued functions with Matplotlib and Mayavi https://nbviewer.ipython.org/github/empet/Math/blob/master/DomainColoring.ipynb Parameters ---------- X : array Array (of any dimension) of complex numbers. theme : 'light' (default) or 'dark' Set coloring theme for mapping complex values into colors. rmax : float Maximal abs value for color normalization. If None (default), uses np.abs(X).max(). Returns ------- Y : array Array of colors (of shape X.shape + (3,)). """ absmax = rmax or np.abs(X).max() if absmax == 0.: absmax = 1. Y = np.zeros(X.shape + (3,), dtype='float') Y[..., 0] = np.angle(X) / (2 * pi) % 1 if theme == 'light': Y[..., 1] = np.clip(np.abs(X) / absmax, 0, 1) Y[..., 2] = 1 elif theme == 'dark': Y[..., 1] = 1 Y[..., 2] = np.clip(np.abs(X) / absmax, 0, 1) Y = mpl.colors.hsv_to_rgb(Y) return Y def _index_to_sequence(i, dim_list): """ For a matrix entry with index i it returns state it corresponds to. In particular, for dim_list=[2]*n it returns i written as a binary number. Parameters ---------- i : int Index in a matrix. dim_list : list of int List of dimensions of consecutive particles. Returns ------- seq : list List of coordinates for each particle. """ res = [] j = i for d in reversed(dim_list): j, s = divmod(j, d) res.append(s) return list(reversed(res)) def _sequence_to_index(seq, dim_list): """ Inverse of _index_to_sequence. Parameters ---------- seq : list of ints List of coordinates for each particle. dim_list : list of int List of dimensions of consecutive particles. Returns ------- i : list Index in a matrix. 
""" i = 0 for s, d in zip(seq, dim_list): i *= d i += s return i def _to_qubism_index_pair(i, dim_list, how='pairs'): """ For a matrix entry with index i it returns x, y coordinates in qubism mapping. Parameters ---------- i : int Index in a matrix. dim_list : list of int List of dimensions of consecutive particles. how : 'pairs' ('default'), 'pairs_skewed' or 'before_after' Type of qubistic plot. Returns ------- x, y : tuple of ints List of coordinates for each particle. """ seq = _index_to_sequence(i, dim_list) if how == 'pairs': y = _sequence_to_index(seq[::2], dim_list[::2]) x = _sequence_to_index(seq[1::2], dim_list[1::2]) elif how == 'pairs_skewed': dim_list2 = dim_list[::2] y = _sequence_to_index(seq[::2], dim_list2) seq2 = [(b - a) % d for a, b, d in zip(seq[::2], seq[1::2], dim_list2)] x = _sequence_to_index(seq2, dim_list2) elif how == 'before_after': # https://en.wikipedia.org/wiki/File:Ising-tartan.png n = len(dim_list) y = _sequence_to_index(reversed(seq[:(n // 2)]), reversed(dim_list[:(n // 2)])) x = _sequence_to_index(seq[(n // 2):], dim_list[(n // 2):]) else: raise Exception("No such 'how'.") return x, y def _sequence_to_latex(seq, style='ket'): """ For a sequence of particle states generate LaTeX code. Parameters ---------- seq : list of ints List of coordinates for each particle. style : 'ket' (default), 'bra' or 'bare' Style of LaTeX (i.e. |01> or <01| or 01, respectively). Returns ------- latex : str LaTeX output. """ if style == 'ket': latex = "$\\left|{0}\\right\\rangle$" elif style == 'bra': latex = "$\\left\\langle{0}\\right|$" elif style == 'bare': latex = "${0}$" else: raise Exception("No such style.") return latex.format("".join(map(str, seq)))
def plot_qubism(ket, theme='light', how='pairs',
                grid_iteration=1, legend_iteration=0,
                fig=None, ax=None, figsize=(6, 6)):
    """
    Qubism plot for pure states of many qudits.  Works best for spin chains,
    especially with even number of particles of the same dimension.  Allows to
    see entanglement between first 2k particles and the rest.

    Parameters
    ----------
    ket : Qobj
        Pure state for plotting.

    theme : 'light' (default) or 'dark'
        Set coloring theme for mapping complex values into colors.
        See: complex_array_to_rgb.

    how : 'pairs' (default), 'pairs_skewed' or 'before_after'
        Type of Qubism plotting.  Options:

        - 'pairs' - typical coordinates,
        - 'pairs_skewed' - for ferromagnetic/antiferromagnetic plots,
        - 'before_after' - related to Schmidt plot (see also: plot_schmidt).

    grid_iteration : int (default 1)
        Helper lines to be drawn on plot.
        Show tiles for 2*grid_iteration particles vs all others.

    legend_iteration : int (default 0) or 'grid_iteration' or 'all'
        Show labels for first ``2*legend_iteration`` particles.  Option
        'grid_iteration' sets the same number of particles as for
        grid_iteration.  Option 'all' makes label for all particles.  Typically
        it should be 0, 1, 2 or perhaps 3.

    fig : a matplotlib figure instance
        The figure canvas on which the plot will be drawn. If omitted while
        `ax` is given, the figure owning `ax` is used.

    ax : a matplotlib axis instance
        The axis context in which the plot will be drawn. If omitted while
        `fig` is given, a subplot is added to `fig`.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, if no 'fig' and 'ax' arguments are passed).

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------
    Exception
        If `ket` is not a ket, if `how` or `legend_iteration` is not one of
        the supported options, or if 'pairs_skewed' is requested for pairs
        of unequal dimensions.

    Notes
    -----
    See also [1]_.

    References
    ----------
    .. [1] J. Rodriguez-Laguna, P. Migdal, M. Ibanez Berganza, M. Lewenstein
       and G. Sierra, *Qubism: self-similar visualization of many-body
       wavefunctions*, `New J. Phys. 14 053028
       <https://dx.doi.org/10.1088/1367-2630/14/5/053028>`_, arXiv:1112.3560
       (2012), open access.
    """
    if not isket(ket):
        raise Exception("Qubism works only for pure states, i.e. kets.")
        # add for dm? (perhaps a separate function, plot_qubism_dm)

    # Fill in whichever of fig/ax is missing; the label code below needs a
    # real figure (for its DPI transform) and the drawing needs real axes.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=figsize)
        else:
            ax = fig.add_subplot(1, 1, 1)
    elif fig is None:
        fig = ax.figure

    dim_list = ket.dims[0]
    n = len(dim_list)

    # for odd number of particles - pixels are rectangular
    if n % 2 == 1:
        ket = tensor(ket, Qobj([1] * dim_list[-1]))
        dim_list = ket.dims[0]
        n += 1

    ketdata = ket.full()

    if how == 'pairs':
        dim_list_y = dim_list[::2]
        dim_list_x = dim_list[1::2]
    elif how == 'pairs_skewed':
        dim_list_y = dim_list[::2]
        dim_list_x = dim_list[1::2]
        if dim_list_x != dim_list_y:
            raise Exception("For 'pairs_skewed' pairs " +
                            "of dimensions need to be the same.")
    elif how == 'before_after':
        dim_list_y = list(reversed(dim_list[:(n // 2)]))
        dim_list_x = dim_list[(n // 2):]
    else:
        raise Exception("No such 'how'.")

    # np.prod of an empty list is the float 1.0, which range() and list
    # repetition reject, so every product is coerced to a plain int.
    size_x = int(np.prod(dim_list_x))
    size_y = int(np.prod(dim_list_y))

    qub = np.zeros([size_x, size_y], dtype=complex)
    for i in range(ketdata.size):
        qub[_to_qubism_index_pair(i, dim_list, how=how)] = ketdata[i, 0]
    qub = qub.transpose()

    quadrants_x = int(np.prod(dim_list_x[:grid_iteration]))
    quadrants_y = int(np.prod(dim_list_y[:grid_iteration]))

    ticks_x = [size_x // quadrants_x * i for i in range(1, quadrants_x)]
    ticks_y = [size_y // quadrants_y * i for i in range(1, quadrants_y)]

    # Unlabelled ticks: they only position the helper grid lines.
    ax.set_xticks(ticks_x)
    ax.set_xticklabels([""] * (quadrants_x - 1))
    ax.set_yticks(ticks_y)
    ax.set_yticklabels([""] * (quadrants_y - 1))
    theme2color_of_lines = {'light': '#000000',
                            'dark': '#FFFFFF'}
    ax.grid(True, color=theme2color_of_lines[theme])
    ax.imshow(complex_array_to_rgb(qub, theme=theme),
              interpolation="none",
              extent=(0, size_x, 0, size_y))

    if legend_iteration == 'all':
        label_n = n // 2
    elif legend_iteration == 'grid_iteration':
        label_n = grid_iteration
    else:
        try:
            label_n = int(legend_iteration)
        except (TypeError, ValueError) as err:
            raise Exception("No such option for legend_iteration keyword " +
                            "argument. Use 'all', 'grid_iteration' or an " +
                            "integer.") from err

    if label_n:

        if how == 'before_after':
            dim_list_small = list(reversed(dim_list_y[-label_n:])) \
                + dim_list_x[:label_n]
        else:
            dim_list_small = []
            for j in range(label_n):
                dim_list_small.append(dim_list_y[j])
                dim_list_small.append(dim_list_x[j])

        labelled_x = int(np.prod(dim_list_x[:label_n]))
        labelled_y = int(np.prod(dim_list_y[:label_n]))

        # Each label sits at the centre of its tile.
        scale_x = float(size_x) / labelled_x
        shift_x = 0.5 * scale_x
        scale_y = float(size_y) / labelled_y
        shift_y = 0.5 * scale_y

        # Font size shrinks with the number of tiles across the axes width.
        bbox = ax.get_window_extent().transformed(
            fig.dpi_scale_trans.inverted())
        fontsize = 35 * bbox.width / labelled_x / label_n
        opts = {'fontsize': fontsize,
                'color': theme2color_of_lines[theme],
                'horizontalalignment': 'center',
                'verticalalignment': 'center'}
        for i in range(int(np.prod(dim_list_small))):
            x, y = _to_qubism_index_pair(i, dim_list_small, how=how)
            seq = _index_to_sequence(i, dim_list=dim_list_small)
            # The image's y axis runs top-down, hence the flip.
            ax.text(scale_x * x + shift_x,
                    size_y - (scale_y * y + shift_y),
                    _sequence_to_latex(seq),
                    **opts)
    return fig, ax
def plot_schmidt(ket, splitting=None,
                 labels_iteration=(3, 2),
                 theme='light',
                 fig=None, ax=None, figsize=(6, 6)):
    """
    Plotting scheme related to Schmidt decomposition.
    Converts a state into a matrix (A_ij -> A_i^j),
    where rows are first particles and columns - last.

    See also: plot_qubism with how='before_after' for a similar plot.

    Parameters
    ----------
    ket : Qobj
        Pure state for plotting.

    splitting : int
        Plot for a number of first particles versus the rest.
        If not given, it is (number of particles + 1) // 2.

    theme : 'light' (default) or 'dark'
        Set coloring theme for mapping complex values into colors.
        See: complex_array_to_rgb.

    labels_iteration : int or pair of ints (default (3,2))
        Number of particles to be shown as tick labels,
        for first (vertical) and last (horizontal) particles, respectively.

    fig : a matplotlib figure instance
        The figure canvas on which the plot will be drawn. If omitted while
        `ax` is given, the figure owning `ax` is used.

    ax : a matplotlib axis instance
        The axis context in which the plot will be drawn. If omitted while
        `fig` is given, a subplot is added to `fig`.

    figsize : (width, height)
        The size of the matplotlib figure (in inches) if it is to be created
        (that is, if no 'fig' and 'ax' arguments are passed).

    Returns
    -------
    fig, ax : tuple
        A tuple of the matplotlib figure and axes instances used to produce
        the figure.

    Raises
    ------
    Exception
        If `ket` is not a ket.
    """
    if not isket(ket):
        raise Exception("Schmidt plot works only for pure states, i.e. kets.")

    # Fill in whichever of fig/ax is missing instead of failing later on a
    # None axis or returning a None figure.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=figsize)
        else:
            ax = fig.add_subplot(1, 1, 1)
    elif fig is None:
        fig = ax.figure

    dim_list = ket.dims[0]

    if splitting is None:
        splitting = (len(dim_list) + 1) // 2

    if isinstance(labels_iteration, int):
        labels_iteration = labels_iteration, labels_iteration

    ketdata = ket.full()

    dim_list_y = dim_list[:splitting]
    dim_list_x = dim_list[splitting:]

    # np.prod of an empty list is the float 1.0 (e.g. a single-particle ket
    # leaves dim_list_x empty); reshape() and range() need plain ints.
    size_x = int(np.prod(dim_list_x))
    size_y = int(np.prod(dim_list_y))

    ketdata = ketdata.reshape((size_y, size_x))

    dim_list_small_x = dim_list_x[:labels_iteration[1]]
    dim_list_small_y = dim_list_y[:labels_iteration[0]]

    quadrants_x = int(np.prod(dim_list_small_x))
    quadrants_y = int(np.prod(dim_list_small_y))

    # Ticks sit at tile centres; the y ticks are listed top-down to match
    # the row order of the image.
    ticks_x = [size_x / quadrants_x * (i + 0.5)
               for i in range(quadrants_x)]
    ticks_y = [size_y / quadrants_y * (quadrants_y - i - 0.5)
               for i in range(quadrants_y)]

    labels_x = [_sequence_to_latex(_index_to_sequence(i * size_x // quadrants_x,
                                                      dim_list=dim_list_x))
                for i in range(quadrants_x)]
    labels_y = [_sequence_to_latex(_index_to_sequence(i * size_y // quadrants_y,
                                                      dim_list=dim_list_y))
                for i in range(quadrants_y)]

    ax.set_xticks(ticks_x)
    ax.set_xticklabels(labels_x)
    ax.set_yticks(ticks_y)
    ax.set_yticklabels(labels_y)
    ax.set_xlabel("last particles")
    ax.set_ylabel("first particles")

    ax.imshow(complex_array_to_rgb(ketdata, theme=theme),
              interpolation="none",
              extent=(0, size_x, 0, size_y))

    return fig, ax