Package solvcon :: Package tests :: Module test_solver
[hide private]
[frames] | no frames]

Source Code for Module solvcon.tests.test_solver

  1  import os 
  2  from unittest import TestCase 
  3  from ..testing import get_blk_from_sample_neu 
  4  from ..solver import BaseSolver, BlockSolver 
  5  from ..testing import TestingSolver 
class CustomBaseSolver(BaseSolver):
    """BaseSolver subclass used to verify that overridden hooks are invoked.

    The equation count is always forced to one, and ``bind`` leaves a marker
    (``self.val = 'bind'``) so tests can tell the override actually ran.
    """

    def __init__(self, **kw):
        # Override whatever ``neq`` the caller supplied with a single equation.
        kw.update(neq=1)
        super(CustomBaseSolver, self).__init__(**kw)

    def bind(self):
        super(CustomBaseSolver, self).bind()
        # Marker checked by TestBase.test_inheritance.
        self.val = 'bind'
14
class CustomBlockSolver(TestingSolver):
    """TestingSolver whose default message output file is the null device.

    Tests that need to inspect messages temporarily override
    ``MESG_FILENAME_DEFAULT`` on this class.
    """

    # Default destination for solver messages: discard them.
    MESG_FILENAME_DEFAULT = os.devnull
17
class TestBase(TestCase):
    """Construction and binding behaviour of BaseSolver and a subclass."""

    def test_base(self):
        # Constructing without the ``neq`` keyword must fail with KeyError.
        self.assertRaises(KeyError, BaseSolver)
        solver = BaseSolver(neq=1)
        self.assertEqual(getattr(solver, 'val', None), None)
        solver.bind()
        # The base implementation of bind() must not create ``val``.
        self.assertEqual(getattr(solver, 'val', None), None)

    def test_inheritance(self):
        solver = CustomBaseSolver()
        # ``val`` only appears once the overridden bind() has run.
        self.assertEqual(getattr(solver, 'val', None), None)
        solver.bind()
        self.assertEqual(solver.val, 'bind')
31
class TestFpdtype(TestCase):
    """The solver's floating-point dtype must follow the global ``env`` setting."""

    def test_fp(self):
        from ..dependency import str_of
        from ..conf import env
        solver = BaseSolver(neq=1)
        expected_dtype = env.fpdtype
        # Both the dtype object and its string name must match the env.
        self.assertEqual(solver.fpdtype, expected_dtype)
        self.assertEqual(solver.fpdtypestr, str_of(expected_dtype))
39
class TestBlock(TestCase):
    """Tests for CustomBlockSolver (a TestingSolver) built on sample meshes."""

    neq = 1

    def test_simplex(self):
        from ..testing import get_blk_from_sample_neu, get_blk_from_oblique_neu
        # The oblique mesh must be reported as all-simplex; the sample mesh not.
        svr = CustomBlockSolver(get_blk_from_oblique_neu(),
            neq=self.neq, enable_mesg=True)
        self.assertTrue(svr.all_simplex)
        svr = CustomBlockSolver(get_blk_from_sample_neu(),
            neq=self.neq, enable_mesg=True)
        self.assertFalse(svr.all_simplex)

    def test_incenter(self):
        from ..testing import get_blk_from_sample_neu, get_blk_from_oblique_neu
        # The solver must reflect the use_incenter flag of the block it wraps.
        svr = CustomBlockSolver(get_blk_from_oblique_neu(use_incenter=True),
            neq=self.neq, enable_mesg=True)
        self.assertTrue(svr.use_incenter)
        svr = CustomBlockSolver(get_blk_from_sample_neu(use_incenter=False),
            neq=self.neq, enable_mesg=True)
        self.assertFalse(svr.use_incenter)

    @staticmethod
    def _get_block():
        """Return the block (mesh) the shared solver fixture is built on."""
        # NOTE(review): this body was missing from the source listing; it is
        # reconstructed as loading the sample .neu mesh (the loader imported
        # at module level) -- confirm against the upstream file.
        from ..testing import get_blk_from_sample_neu
        return get_blk_from_sample_neu()

    @classmethod
    def _get_solver(cls, init=True):
        """Create a CustomBlockSolver on ``_get_block()``.

        :param init: when true, also call ``bind()`` and ``init()``.
        :return: the solver.

        Warnings emitted during bind/init are suppressed. ``catch_warnings``
        restores the caller's filter list afterwards, even if bind/init
        raises. The previous ``resetwarnings()`` call instead erased every
        filter, including those set by the interpreter or test runner.
        """
        import warnings
        svr = CustomBlockSolver(cls._get_block(), neq=cls.neq, enable_mesg=True)
        if init:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                svr.bind()
                svr.init()
        return svr

    def test_debug(self):
        import sys
        try:
            from cStringIO import StringIO
        except ImportError:  # Python 3 has no cStringIO.
            from io import StringIO
        saved_default = CustomBlockSolver.MESG_FILENAME_DEFAULT
        saved_stdout = sys.stdout
        # NOTE(review): the string 'sys.stdout' is presumably interpreted by the
        # solver as "write to the current sys.stdout"; confirm in solver.mesg.
        CustomBlockSolver.MESG_FILENAME_DEFAULT = 'sys.stdout'
        sys.stdout = StringIO()
        try:
            svr = self._get_solver(init=True)
            svr.mesg('test message')
            self.assertEqual(sys.stdout.getvalue(), 'test message')
        finally:
            # Always undo the global/class-level changes so a failure here
            # cannot redirect output or change the default for later tests.
            sys.stdout = saved_stdout
            CustomBlockSolver.MESG_FILENAME_DEFAULT = saved_default

    def test_create(self):
        svr = self._get_solver()
        self.assertTrue(svr)

    def test_bound_full(self):
        svr = self._get_solver()
        svr.bind()
        self.assertTrue(svr.is_bound)
        self.assertFalse(svr.is_unbound)

    def test_unbound_full(self):
        svr = self._get_solver()
        svr.unbind()
        self.assertFalse(svr.is_bound)
        self.assertTrue(svr.is_unbound)

    def test_neq(self):
        svr = self._get_solver()
        self.assertEqual(svr.neq, self.neq)

    def test_blkn(self):
        svr = self._get_solver()
        self.assertEqual(svr.svrn, None)

    def test_metric(self):
        svr = self._get_solver()
        # cecnd: (ncell+ngstcell, CLMFC+1, ndim).
        self.assertEqual(len(svr.cecnd.shape), 3)
        self.assertEqual(svr.cecnd.shape[0], svr.ncell+svr.ngstcell)
        self.assertEqual(svr.cecnd.shape[1], svr.CLMFC+1)
        self.assertEqual(svr.cecnd.shape[2], svr.ndim)
        # cevol: (ncell+ngstcell, CLMFC+1).
        self.assertEqual(len(svr.cevol.shape), 2)
        self.assertEqual(svr.cevol.shape[0], svr.ncell+svr.ngstcell)
        self.assertEqual(svr.cevol.shape[1], svr.CLMFC+1)

    def test_solution(self):
        svr = self._get_solver()
        # sol/soln: (ncell+ngstcell, neq), and the two must agree in shape.
        self.assertEqual(len(svr.sol.shape), 2)
        self.assertEqual(svr.sol.shape[0], svr.soln.shape[0])
        self.assertEqual(svr.sol.shape[1], svr.soln.shape[1])
        self.assertEqual(svr.sol.shape[0], svr.ncell+svr.ngstcell)
        self.assertEqual(svr.sol.shape[1], svr.neq)
        # dsol/dsoln: (ncell+ngstcell, neq, ndim), and must agree in shape.
        self.assertEqual(len(svr.dsol.shape), 3)
        self.assertEqual(svr.dsol.shape[0], svr.dsoln.shape[0])
        self.assertEqual(svr.dsol.shape[1], svr.dsoln.shape[1])
        self.assertEqual(svr.dsol.shape[2], svr.dsoln.shape[2])
        self.assertEqual(svr.dsol.shape[0], svr.ncell+svr.ngstcell)
        self.assertEqual(svr.dsol.shape[1], svr.neq)
        self.assertEqual(svr.dsol.shape[2], svr.ndim)

    # Marching parameters shared by test_soln and test_dsoln.
    time = 0.0
    time_increment = 1.0
    nsteps = 10

    def _run_solver(self, time, time_increment, nsteps):
        """Zero the next-step solution arrays, march, and return the solver."""
        # initialize.
        svr = self._get_solver()
        svr.soln.fill(0.0)
        svr.dsoln.fill(0.0)
        # run.
        svr.march(time, time_increment, nsteps)
        return svr

    def test_soln(self):
        from numpy import zeros
        # run.
        svr = self._run_solver(self.time, self.time_increment, self.nsteps)
        ngstcell = svr.ngstcell
        # get result (interior cells only; ghost cells come first).
        soln = svr.soln[ngstcell:,0]
        # calculate reference: accumulate clvol*dt/2 once per half step,
        # in the same order, so the exact-equality comparison stays valid.
        clvol = zeros(soln.shape, dtype=soln.dtype)
        for iistep in range(self.nsteps*2):
            clvol += svr.clvol[ngstcell:]*self.time_increment/2
        # compare.
        self.assertTrue((soln==clvol).all())

    def test_dsoln(self):
        from numpy import zeros
        # run.
        svr = self._run_solver(self.time, self.time_increment, self.nsteps)
        ngstcell = svr.ngstcell
        # get result (interior cells only; ghost cells come first).
        dsoln = svr.dsoln[ngstcell:,0,:]
        # calculate reference: accumulate clcnd*dt/2 once per half step.
        clcnd = zeros(dsoln.shape, dtype=dsoln.dtype)
        for iistep in range(self.nsteps*2):
            clcnd += svr.clcnd[ngstcell:]*self.time_increment/2
        # compare.
        self.assertTrue((dsoln==clcnd).all())
177