Skip to content

Commit

Permalink
add test structure
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Aug 17, 2023
1 parent 90710e3 commit ef2cc3a
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from pymc.logprob.basic import conditional_logp, transformed_conditional_logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Point, ValueGradFunction, modelcontext
from pymc.model_graph import model_to_graphviz, model_to_networkx
from pymc.util import _FutureWarningValidatingScratchpad
from pymc.variational.minibatch_rv import MinibatchRandomVariable
from tests.models import simple_model
Expand Down Expand Up @@ -1653,3 +1654,26 @@ def test_model_logp_fast_compile():

with pytensor.config.change_flags(mode="FAST_COMPILE"):
assert m.point_logps() == {"a": -1.5}


class TestModelGraphs:
@staticmethod
def school_model() -> pm.Model:
J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
with pm.Model() as schools:
eta = pm.Normal("eta", 0, 1, shape=J)
mu = pm.Normal("mu", 0, sigma=1e6)
tau = pm.HalfCauchy("tau", 25)
theta = mu + tau * eta
obs = pm.Normal("obs", theta, sigma=sigma, observed=y)
return schools

def test_graphviz(self) -> None:
model = self.school_model()
assert model.graphviz()

def test_networkx(self) -> None:
model = self.school_model()
model_to_networkx(model=model) == model.networkx()

0 comments on commit ef2cc3a

Please sign in to comment.