A quick demo of how to represent a computational graph with Python objects.

In [2]:
class Node:
    """Base class for a node in a computational graph.

    Subclasses implement ``value()`` to compute the node's result from their
    inputs. The operators ``+``, ``-`` and ``*`` are overloaded so that
    combining a node with another node (or with a plain number) builds a new
    graph node instead of computing a number immediately. Plain-number
    operands are converted with ``float()`` and wrapped in a ``ConstantNode``.
    """

    def value(self):
        """Compute and return this node's value.

        Raises:
            NotImplementedError: on the base class; concrete node types must
                override this method.
        """
        # Failing loudly here is clearer than returning None, which would only
        # surface later as a TypeError inside some parent arithmetic node.
        raise NotImplementedError(
            f"{type(self).__name__} must implement value()"
        )

    @staticmethod
    def _as_node(x):
        """Return ``x`` unchanged if it is a Node, else wrap it as a constant.

        Non-Node operands go through ``float()``, so anything ``float()``
        rejects raises its TypeError/ValueError here.
        """
        if isinstance(x, Node):
            return x
        return ConstantNode(float(x))

    def __add__(self, x):
        """Build the node ``self + x``."""
        return SumNode(self, self._as_node(x))

    def __sub__(self, x):
        """Build the node ``self - x``."""
        return DiffNode(self, self._as_node(x))

    def __mul__(self, x):
        """Build the node ``self * x``."""
        return ProdNode(self, self._as_node(x))

    # Reflected operators: Python calls these for ``x <op> self`` when the
    # left operand (e.g. an int or float) cannot combine with a Node itself.
    # ``x`` is the LEFT operand, so it must stay on the left of the new node.

    def __radd__(self, x):
        """Build the node ``x + self``."""
        return SumNode(self._as_node(x), self)

    def __rsub__(self, x):
        """Build the node ``x - self``.

        Subtraction is not commutative, so this must not delegate to
        ``__sub__``: ``3 - node`` means ``3 - node.value()``, not
        ``node.value() - 3``.
        """
        return DiffNode(self._as_node(x), self)

    def __rmul__(self, x):
        """Build the node ``x * self``."""
        return ProdNode(self._as_node(x), self)

class ParameterNode(Node):
    """Leaf node holding an adjustable input value.

    Nodes built on top of this one read its value only when their own
    ``value()`` is called, so ``set_value`` lets the same graph be
    re-evaluated at a new input without being rebuilt.
    """

    def __init__(self, param_value):
        # Stored as a plain public attribute; value()/set_value() read and
        # write it directly.
        self.param_value = param_value

    def value(self):
        """Return the current parameter value."""
        return self.param_value

    def set_value(self, pv):
        """Replace the parameter value with ``pv``."""
        self.param_value = pv

    
class ConstantNode(Node):
    """Leaf node for a constant value supplied at construction.

    Unlike ParameterNode it offers no setter. Node's arithmetic operators
    create these automatically to wrap plain-number operands.
    """

    def __init__(self, const_value):
        self.const_value = const_value

    def value(self):
        """Return the stored constant."""
        return self.const_value

    
class SumNode(Node):
    """Interior node representing ``arg1 + arg2``.

    Both operands are re-evaluated (left first) on every ``value()`` call;
    no result is cached.
    """

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = arg1, arg2

    def value(self):
        """Return ``arg1.value() + arg2.value()``."""
        return self.arg1.value() + self.arg2.value()


class DiffNode(Node):
    """Interior node representing ``arg1 - arg2``.

    Both operands are re-evaluated (left first) on every ``value()`` call;
    no result is cached.
    """

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = arg1, arg2

    def value(self):
        """Return ``arg1.value() - arg2.value()``."""
        return self.arg1.value() - self.arg2.value()


class ProdNode(Node):
    """Interior node representing ``arg1 * arg2``.

    Both operands are re-evaluated (left first) on every ``value()`` call;
    no result is cached.
    """

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = arg1, arg2

    def value(self):
        """Return ``arg1.value() * arg2.value()``."""
        return self.arg1.value() * self.arg2.value()
In [14]:
# The graph's adjustable input; later set_value() calls change it in place.
x = ParameterNode(param_value=1.0)
In [15]:
# y = 3 * (x + 1), built as a graph rather than computed immediately.
# `x + 1` uses Node.__add__; `3 * ...` falls through to Node.__rmul__
# because int does not know how to multiply a Node.
y = 3 * (x + 1)
In [ ]:
 
In [22]:
# Evaluate the graph at the current x; the cell echoes the result.
y.value()
Out[22]:
6.0

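We don't need to rebuild the computational graph to evaluate the function at a different parameter value: updating the parameter node is enough, because `value()` recomputes the result from the current inputs every time it is called.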

In [24]:
# Move the input to 3.0 and re-evaluate the existing graph for y.
x.set_value(3.0)
print(x.value(), " -> ", y.value())
3.0  ->  12.0
In [25]:
# u = 2*x**2 + x - 1, sharing the same input node x as y.
# `2 * (...)` goes through Node.__rmul__; `+ x` and `- 1` use
# Node.__add__ and Node.__sub__.
u = 2 * (x * x) + x - 1
In [26]:
# Evaluate u at the current x; the cell echoes the result.
u.value()
Out[26]:
20.0
In [27]:
# Change the input again and re-evaluate u without rebuilding its graph.
x.set_value(0.0)
u.value()
Out[27]:
-1.0
In [ ]:
 
In [ ]: