Plotting mathematical functions in VisIt

VisIt is most often used to view the output of simulation codes but it can be used for a variety of other purposes, including visualization of mathematical functions. The mathematical function data can be saved to any one of the supported file formats that VisIt can read and then the data can be read and plotted by VisIt.

PlotFunction script

If you want a more interactive way of visualizing mathematical functions, you can use the PlotFunction function in the following script. The PlotFunction function takes as input a string containing the mathematical function that you want the Python interpreter to execute. PlotFunction then evaluates that function at many points on a mesh and saves the data to a file (a curve file for 1-D functions, a VTK file for 2-D and 3-D functions). PlotFunction then tells VisIt to plot the data from that file in the most appropriate manner (i.e. 1-D data plots as curves, 2-D data plots as surfaces, etc.).

from math import *
import random

# Counter used to give every generated data file a distinct name.
fileBaseIndex = 0

def GetTempFileBase():
    """Return the next temporary file base name (no extension).

    The name is "/var/tmp/plotdata" followed by the current value of the
    module-level fileBaseIndex, zero-padded to four digits. The counter is
    advanced by one on every call.
    """
    global fileBaseIndex
    current = fileBaseIndex
    fileBaseIndex = current + 1
    return "/var/tmp/plotdata" + "%04d" % current

def _Fail(message):
    """Print message for interactive users and raise RuntimeError with it."""
    print(message)
    raise RuntimeError(message)


def _SamplePoints(lo, hi, nsteps):
    """Return nsteps evenly spaced samples running from lo to hi inclusive.

    With a single step only lo is sampled, which avoids dividing by
    float(nsteps-1) == 0.
    """
    denom = float(nsteps - 1)
    points = []
    for i in range(nsteps):
        if denom > 0.:
            t = float(i) / denom
        else:
            t = 0.
        points.append(t*hi + (1.-t)*lo)
    return points


def _FormatCoords(points, perLine=0):
    """Format coordinates as " %g" items, optionally broken into lines.

    When perLine is positive a newline is inserted after every item whose
    index is a positive multiple of perLine, except after the last item.
    """
    pieces = []
    last = len(points) - 1
    for i, p in enumerate(points):
        pieces.append(" %g" % p)
        if perLine > 0 and 0 < i < last and i % perLine == 0:
            pieces.append("\n")
    return "".join(pieces)


def _WriteGridHeader(f, coords, varName, vectorData):
    """Write the legacy-VTK rectilinear grid header and the field header.

    coords is a sequence of three (count, text) pairs for the X, Y and Z
    axes; text is the already formatted coordinate list for that axis.
    """
    f.write("# vtk DataFile Version 3.0\n")
    f.write("vtk output\n")
    f.write("ASCII\n")
    f.write("DATASET RECTILINEAR_GRID\n")
    f.write("DIMENSIONS %d %d %d\n" % tuple(n for n, _ in coords))
    npoints = 1
    for axis, (n, text) in zip("XYZ", coords):
        f.write("%s_COORDINATES %d float\n" % (axis, n))
        f.write(text + "\n")
        npoints *= n
    f.write("\n")
    f.write("POINT_DATA %d\n" % npoints)
    if vectorData:
        f.write("VECTORS %s float\n" % varName)
    else:
        f.write("SCALARS %s float 1\n" % varName)
        f.write("LOOKUP_TABLE %s\n" % varName)


def _VectorLine(value, func):
    """Return the data-file line for one vector sample.

    Two-component vectors are padded with a zero third component. Raises
    RuntimeError if value is not a tuple/list of 2 or 3 components.
    """
    if not isinstance(value, (tuple, list)):
        _Fail("The function  %s  did not produce a vector." % func)
    # "%" formatting needs a tuple; a list would be treated as one argument.
    if len(value) == 2:
        return "%g %g 0.\n" % tuple(value)
    if len(value) == 3:
        return "%g %g %g\n" % tuple(value)
    _Fail("The number of components of a vector must be 2 or 3.")


def PlotFunction(func, xmin=0., xmax=1., ymin=0., ymax=1., zmin=0., zmax=1.,
                 nsteps=100, showMesh=True, color=(255,0,0,255),
                 annotColor=(255,0,0,255)):
    """Sample a function given as text, write it to a file and plot it.

    func has the form "f(<args>) = <expression>", where <args> is x, y,
    "x,y" or "x,y,z" and <expression> is a Python expression evaluated with
    eval() at every sample point. One-argument functions are written to a
    .curve file and drawn as a Curve plot. Two- and three-argument functions
    are written to a VTK rectilinear grid and drawn as a Pseudocolor plot
    (elevated for 2D) or, when the expression yields a tuple/list, as a
    Vector plot.

    This function calls VisIt CLI functions (OpenDatabase, AddPlot,
    DrawPlots, ...) and GetTempFileBase, which must already be available as
    globals.

    Arguments:
        func       -- the function description string.
        xmin, xmax -- sampling range in x.
        ymin, ymax -- sampling range in y.
        zmin, zmax -- sampling range in z; for 2D scalar functions they are
                      replaced by the computed value range before the view
                      focus is set.
        nsteps     -- number of samples per dimension (at least 1).
        showMesh   -- show the mesh (2D/3D) or the curve points (1D).
        color      -- color for the curve, mesh or vectors.
        annotColor -- color of the title annotation text.

    Raises RuntimeError when func cannot be parsed, uses unsupported
    parameters, nsteps is less than 1, or a vector result is malformed.
    """
    # func = "f(x) = log(x)"
    leftP = func.find("(")
    rightP = func.find(")")
    equals = func.find("=")
    # str.find returns -1 when the character is missing, so test for <= 0:
    # that rejects both a missing character and one at the very start.
    if leftP <= 0 or rightP <= 0 or equals <= 0:
        _Fail("The function must have the form: f(x) = expression")
    if not (leftP < rightP < equals):
        _Fail("The function must have the form: f(x) = expression")

    # Figure out the argument list for the new function. When an argument
    # names several axes, the last of x, y, z that appears wins.
    variables = func[leftP+1:rightP].split(",")
    nargs = len(variables)
    for i, v in enumerate(variables):
        lowered = v.lower()
        for name in ("x", "y", "z"):
            if name in lowered:
                variables[i] = name

    # Figure out the function body for the new function.
    funcBody = func[equals+1:].lstrip(" ")
    if len(funcBody) < 1:
        _Fail("The function definition is empty.")

    if nsteps < 1:
        _Fail("nsteps must be at least 1.")

    # NOTE(security): funcBody is passed to eval() below, so func must come
    # from a trusted source. eval() is kept in this function so that the
    # expression sees the same local names (x, y, z, ...) as before.
    if nargs == 1:
        if variables[0] not in ("x", "y"):
            _Fail("1D curves can only use x or y as function parameters")

        # Create the data file that we'll use for the plot.
        datafile = GetTempFileBase() + ".curve"
        with open(datafile, "w") as f:
            if variables[0] == "x":
                f.write("# Fx\n")
                for x in _SamplePoints(xmin, xmax, nsteps):
                    try:
                        fx = eval(funcBody)
                    except ZeroDivisionError:
                        print("Divide by zero f(%g)" % x)
                        fx = 0.
                    except OverflowError:
                        print("Math range error f(%g)" % x)
                        continue
                    except ValueError:
                        # e.g. log(0) or sqrt(-1): skip the sample.
                        print("Math domain error f(%g)" % x)
                        continue
                    f.write("%g %g\n" % (x, fx))
            else:
                f.write("# Fy\n")
                for y in _SamplePoints(ymin, ymax, nsteps):
                    try:
                        fy = eval(funcBody)
                    except ZeroDivisionError:
                        print("Divide by zero f(%g)" % y)
                        fy = 0.
                    except OverflowError:
                        print("Math range error f(%g)" % y)
                        continue
                    except ValueError:
                        print("Math domain error f(%g)" % y)
                        continue
                    # The function value goes on the horizontal axis.
                    f.write("%g %g\n" % (fy, y))

        # Create a curve plot.
        OpenDatabase(datafile)
        AddPlot("Curve", "F" + variables[0])
        c = CurveAttributes()
        c.showPoints = int(showMesh)
        c.color = color
        SetPlotOptions(c)
        DrawPlots()
        ResetView()
    elif nargs == 2:
        if variables[0] != "x" or variables[1] != "y":
            _Fail("2D functions can only use x or y as function parameters and "
                  "they must be in the order: x,y.")

        xs = _SamplePoints(xmin, xmax, nsteps)
        ys = _SamplePoints(ymin, ymax, nsteps)

        # Determine whether we have vector data by probing one corner. This
        # happens before the file is created so a failing probe leaves no
        # partial file behind.
        x = xmin
        y = ymin
        fxy = eval(funcBody)
        vectorData = isinstance(fxy, (tuple, list))
        if vectorData:
            zero = (0., 0., 0.)
        else:
            zero = 0.

        # Create the VTK data file that we'll use for the plot.
        datafile = GetTempFileBase() + ".vtk"
        with open(datafile, "w") as f:
            _WriteGridHeader(f, [(nsteps, _FormatCoords(xs)),
                                 (nsteps, _FormatCoords(ys)),
                                 (1, "0.000")], "Fxy", vectorData)
            firstZ = True
            for y in ys:
                for x in xs:
                    try:
                        fxy = eval(funcBody)
                    except ZeroDivisionError:
                        print("Divide by zero f(%g,%g)" % (x,y))
                        fxy = zero
                    except OverflowError:
                        print("Math range error f(%g,%g)" % (x,y))
                        fxy = zero
                    except ValueError:
                        print("Math domain error f(%g,%g)" % (x,y))
                        fxy = zero
                    if vectorData:
                        f.write(_VectorLine(fxy, func))
                    else:
                        f.write("%g\n" % fxy)
                        # Compute zmin, zmax
                        if firstZ:
                            zmin = fxy
                            zmax = fxy
                            firstZ = False
                        else:
                            if fxy < zmin:
                                zmin = fxy
                            if fxy > zmax:
                                zmax = fxy

        # Do the appropriate plot
        OpenDatabase(datafile)
        if vectorData:
            if showMesh:
                AddPlot("Mesh", "mesh")
                m = MeshAttributes()
                m.legendFlag = 0
                SetPlotOptions(m)
            # Create a vector plot.
            AddPlot("Vector", "Fxy")
            v = VectorAttributes()
            v.useStride = 1
            v.vectorColor = color
            SetPlotOptions(v)
        else:
            # Create a Surface plot.
            AddPlot("Pseudocolor", "Fxy")
            plots = [0]
            if showMesh:
                AddPlot("Mesh", "mesh")
                m = MeshAttributes()
                m.meshColor = color
                m.foregroundFlag = 0
                m.legendFlag = 0
                SetPlotOptions(m)
                plots += [1]
            # Elevate every plot by the function value.
            SetActivePlots(tuple(plots))
            AddOperator("Elevate")
            elevAtts = ElevateAttributes()
            elevAtts.variable = "Fxy"
            SetOperatorOptions(elevAtts)
            SetActivePlots(0)
        DrawPlots()
        ResetView()

        v = GetView3D()
        v.viewNormal = (0.766658, 0.500028, 0.402749)
        v.focus = ((xmax+xmin) / 2., (ymax+ymin) / 2., (zmax+zmin) / 2.)
        v.viewUp = (-0.318885, -0.247904, 0.914798)
        SetView3D(v)
    elif nargs == 3:
        if variables != ["x", "y", "z"]:
            _Fail("3D functions can only use x or y or z as function parameters and "
                  "they must be in the order: x,y,z.")

        xs = _SamplePoints(xmin, xmax, nsteps)
        ys = _SamplePoints(ymin, ymax, nsteps)
        zs = _SamplePoints(zmin, zmax, nsteps)

        # Determine whether we have vector data by probing one corner.
        x = xmin
        y = ymin
        z = zmin
        fxyz = eval(funcBody)
        vectorData = isinstance(fxyz, (tuple, list))
        if vectorData:
            zero = (0., 0., 0.)
        else:
            zero = 0.

        # Create the VTK data file that we'll use for the plot.
        datafile = GetTempFileBase() + ".vtk"
        with open(datafile, "w") as f:
            _WriteGridHeader(f, [(nsteps, _FormatCoords(xs, 9)),
                                 (nsteps, _FormatCoords(ys, 9)),
                                 (nsteps, _FormatCoords(zs, 9))],
                             "Fxyz", vectorData)
            # x varies fastest, then y, then z.
            for z in zs:
                for y in ys:
                    for x in xs:
                        try:
                            fxyz = eval(funcBody)
                        except ZeroDivisionError:
                            print("Divide by zero f(%g,%g,%g)" % (x,y,z))
                            fxyz = zero
                        except OverflowError:
                            print("Math range error f(%g,%g,%g)" % (x,y,z))
                            fxyz = zero
                        except ValueError:
                            print("Math domain error f(%g,%g,%g)" % (x,y,z))
                            fxyz = zero
                        if vectorData:
                            f.write(_VectorLine(fxyz, func))
                        else:
                            f.write("%g\n" % fxyz)

        OpenDatabase(datafile)
        if showMesh:
            AddPlot("Mesh", "mesh")
            m = MeshAttributes()
            m.legendFlag = 0
            m.meshColor = color
            m.foregroundFlag = 0
            m.opaqueMode = m.Off
            SetPlotOptions(m)
        if vectorData:
            # Create a vector plot.
            AddPlot("Vector", "Fxyz")
            v = VectorAttributes()
            v.useStride = 1
            v.vectorColor = color
            SetPlotOptions(v)
        else:
            # Create a Pseudocolor plot.
            AddPlot("Pseudocolor", "Fxyz")
        DrawPlots()
        ResetView()

        v = GetView3D()
        v.viewNormal = (0.766658, 0.500028, 0.402749)
        v.focus = ((xmax+xmin) / 2., (ymax+ymin) / 2., (zmax+zmin) / 2.)
        v.viewUp = (-0.318885, -0.247904, 0.914798)
        SetView3D(v)
    else:
        _Fail("Functions must take 1, 2 or 3 parameters.")

    # Turn off the database annotation.
    a = GetAnnotationAttributes()
    a.databaseInfoFlag = 0
    SetAnnotationAttributes(a)

    # Create an annotation object for the title.
    newName = "Text2D_" + str(random.randrange(1000))
    title = CreateAnnotationObject("Text2D", newName)
    title.text = func
    title.useForegroundForTextColor = 0
    title.textColor = annotColor
    title.width = float(len(func)) * 0.01

    # Change the layout of the annotations so that we can keep them organized:
    # stack every 2D text annotation downwards from the top-left corner.
    legendY = 0.96
    for aaName in GetAnnotationObjectNames():
        try:
            a = GetAnnotationObject(aaName)
            if type(a).__name__ == 'Text2DObject':
                a.SetPosition(0.01, legendY)
                legendY -= 0.03
        except Exception:
            # Best effort: an annotation that cannot be moved is left alone.
            pass

Usage

The PlotFunction function takes many arguments but only the func argument is required. The func argument is a string composed of two pieces separated by an equals sign: a function descriptor (f(x), f(y), f(x,y), or f(x,y,z)) and a function definition, which must be a valid Python expression involving the variables x, y, z.

Argument Description Default value
func The function to be evaluated none
xmin The minimum value to use in the X-dimension 0.
xmax The maximum value to use in the X-dimension 1.
ymin The minimum value to use in the Y-dimension; only used when the function involves y. 0.
ymax The maximum value to use in the Y-dimension; only used when the function involves y. 1.
zmin The minimum value to use in the Z-dimension; only used when the function involves z. The zmin is calculated dynamically for 2-d functions. 0.
zmax The maximum value to use in the Z-dimension; only used when the function involves z. The zmax is calculated dynamically for 2-d functions. 1.
nsteps The number of samples to take in each dimension 100
showMesh Whether the mesh should be shown for 2D and 3D data; for 1D curves it controls whether the sample points are shown. True
color The mesh color (255,0,0,255)
annotColor The annotation text color (255,0,0,255)

Examples

Here are some examples that show how to use the PlotFunction function.

# Curves: functions of a single variable (x or y)
PlotFunction("f(x) = x * x", xmin=0., xmax=5.)
PlotFunction("f(x) = x * x + 2.", xmin=-5.0, xmax=5.,
             color=(0,255,0,255), annotColor=(0,255,0,255))
PlotFunction("f(y) = 0.25 * y * y", ymin=-5.0, ymax=5.,
             color=(0,255,255,255), annotColor=(0,255,255,255))

SetCloneWindowOnFirstRef(0)

# Scalar function of x and y, shown in its own window
AddWindow()
PlotFunction("f(x,y) = -10. * sqrt(abs(x * y))",
             color=(0,0,0,255), showMesh=True,
             xmin=-5., xmax=5., ymin=-5.0, ymax=5., nsteps=20)

# Another scalar function of x and y
AddWindow()
PlotFunction("f(x,y) = (x*x - y*y) / (x*x + y*y)",
             color=(0,0,0,255), showMesh=True,
             xmin=-5., xmax=5., ymin=-5.0, ymax=5., nsteps=50)

# Vector-valued function of x and y
AddWindow()
PlotFunction("f(x,y) = (cos(x), sin(y))",
             color=(0,0,0,255), showMesh=False,
             xmin=-5., xmax=5., ymin=-5.0, ymax=5., nsteps=50)

# Scalar function of x, y and z
AddWindow()
PlotFunction("f(x,y,z) = ((x*x - y*y) / (x*x + y*y)) * sqrt(x*x + y*y + z*z)",
             showMesh=True,
             xmin=-5., xmax=5., ymin=-5.0, ymax=5.,
             zmin=-5., zmax=5., nsteps=30)

Example with picture

# Radially symmetric cosine surface sampled on a 100 x 100 grid
PlotFunction("f(x,y) = cos(sqrt(x*x + y*y))",
             color=(0,0,0,255), showMesh=True,
             xmin=-15., xmax=15., ymin=-15.0, ymax=15., nsteps=100)

Function.png