Package solvcon :: Package tests :: Module test_solver
[hide private]
[frames | no frames]

Source Code for Module solvcon.tests.test_solver

  1  import os 
  2  from unittest import TestCase 
  3  from ..testing import get_blk_from_sample_neu 
  4  from ..solver import BaseSolver, BlockSolver 
  5  from ..testing import TestingSolver 
class CustomBaseSolver(BaseSolver):
    """
    Minimal BaseSolver subclass used to check that overridden hooks run.

    The constructor always forces ``neq`` to 1 (overriding any caller value),
    and :py:meth:`bind` records that it ran by setting ``self.val`` to
    ``'bind'`` after delegating to the base implementation.
    """

    def __init__(self, **kw):
        # Always a single equation, whatever the caller passed.
        kw.update(neq=1)
        super(CustomBaseSolver, self).__init__(**kw)

    def bind(self):
        super(CustomBaseSolver, self).bind()
        # Marker inspected by TestBase.test_inheritance.
        self.val = 'bind'

class CustomBlockSolver(TestingSolver):
    """
    TestingSolver whose default message file is the null device.

    ``MESG_FILENAME_DEFAULT`` is set to ``os.devnull``; presumably this is
    where ``mesg`` output goes unless overridden (TestBlock.test_debug
    temporarily switches it to ``'sys.stdout'``) -- confirm in TestingSolver.
    """

    MESG_FILENAME_DEFAULT = os.devnull

class TestBase(TestCase):
    """Exercise BaseSolver directly and through a minimal subclass."""

    @staticmethod
    def _marker(solver):
        """Return the solver's ``val`` attribute, or None when it is unset."""
        return getattr(solver, 'val', None)

    def test_base(self):
        # Constructing with no keyword arguments must raise KeyError.
        self.assertRaises(KeyError, BaseSolver)
        solver = BaseSolver(neq=1)
        # The plain base class never sets ``val``, before or after bind().
        self.assertEqual(self._marker(solver), None)
        solver.bind()
        self.assertEqual(self._marker(solver), None)

    def test_inheritance(self):
        solver = CustomBaseSolver()
        # ``val`` only appears once the overridden bind() has run.
        self.assertEqual(self._marker(solver), None)
        solver.bind()
        self.assertEqual(solver.val, 'bind')

class TestFpdtype(TestCase):
    """Check that a solver exposes the floating-point type configured in env."""

    def test_fp(self):
        from ..dependency import pointer_of, str_of
        from ..conf import env
        solver = BaseSolver(neq=1)
        configured = env.fpdtype
        # dtype itself, its string name, and its pointer type must all agree
        # with the global configuration.
        self.assertEqual(solver.fpdtype, configured)
        self.assertEqual(solver.fpdtypestr, str_of(configured))
        self.assertEqual(solver.fpptr, pointer_of(configured))

class TestBlock(TestCase):
    """
    Exercise :py:class:`CustomBlockSolver` on the mesh block returned by
    ``get_blk_from_sample_neu()``.
    """

    neq = 1

    @staticmethod
    def _get_block():
        """Return the mesh block every solver in this test case is built on."""
        # NOTE(review): this body was missing from the rendered listing; it is
        # reconstructed from the module-level get_blk_from_sample_neu import,
        # which nothing else in the module uses -- confirm.
        return get_blk_from_sample_neu()

    @classmethod
    def _get_solver(cls, init=True):
        """
        Create a CustomBlockSolver for this test case.

        :param init: when true, also call ``bind()`` and ``init()`` on the new
            solver, with warnings ignored during those two calls.
        :return: the solver.
        """
        import warnings
        svr = CustomBlockSolver(cls._get_block(), neq=cls.neq,
            enable_mesg=True)
        if init:
            # catch_warnings() saves and restores the global filter list, so
            # the "ignore" filter cannot leak out, even if bind()/init()
            # raises, and filters installed by the test runner survive
            # (resetwarnings() would have wiped them).
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                svr.bind()
                svr.init()
        return svr

    def test_debug(self):
        import sys
        try:
            from cStringIO import StringIO
        except ImportError:  # Python 3 has no cStringIO.
            from io import StringIO
        saved_default = CustomBlockSolver.MESG_FILENAME_DEFAULT
        saved_stdout = sys.stdout
        CustomBlockSolver.MESG_FILENAME_DEFAULT = 'sys.stdout'
        sys.stdout = StringIO()
        try:
            svr = self._get_solver(init=True)
            svr.mesg('test message')
            captured = sys.stdout.getvalue()
        finally:
            # Restore global state even when the solver fails, so later tests
            # are neither captured nor redirected.
            sys.stdout = saved_stdout
            CustomBlockSolver.MESG_FILENAME_DEFAULT = saved_default
        # Assert only after restoring stdout so a failure report is visible.
        self.assertEqual(captured, 'test message')

    def test_create(self):
        svr = self._get_solver()
        self.assertTrue(svr)

    def test_bound_full(self):
        svr = self._get_solver()
        svr.bind()
        self.assertTrue(svr.is_bound)
        self.assertFalse(svr.is_unbound)

    def test_unbound_full(self):
        svr = self._get_solver()
        svr.unbind()
        self.assertFalse(svr.is_bound)
        self.assertTrue(svr.is_unbound)

    def test_neq(self):
        svr = self._get_solver()
        self.assertEqual(svr.neq, self.neq)

    def test_blkn(self):
        svr = self._get_solver()
        self.assertEqual(svr.svrn, None)

    def test_metric(self):
        svr = self._get_solver()
        nrow = svr.ncell + svr.ngstcell
        # cecnd: (ncell+ngstcell, CLMFC+1, ndim).
        self.assertEqual(len(svr.cecnd.shape), 3)
        self.assertEqual(svr.cecnd.shape[0], nrow)
        self.assertEqual(svr.cecnd.shape[1], svr.CLMFC+1)
        self.assertEqual(svr.cecnd.shape[2], svr.ndim)
        # cevol: (ncell+ngstcell, CLMFC+1).
        self.assertEqual(len(svr.cevol.shape), 2)
        self.assertEqual(svr.cevol.shape[0], nrow)
        self.assertEqual(svr.cevol.shape[1], svr.CLMFC+1)

    def test_solution(self):
        svr = self._get_solver()
        nrow = svr.ncell + svr.ngstcell
        # sol matches soln and is (ncell+ngstcell, neq).
        self.assertEqual(len(svr.sol.shape), 2)
        self.assertEqual(svr.sol.shape[0], svr.soln.shape[0])
        self.assertEqual(svr.sol.shape[1], svr.soln.shape[1])
        self.assertEqual(svr.sol.shape[0], nrow)
        self.assertEqual(svr.sol.shape[1], svr.neq)
        # dsol matches dsoln and is (ncell+ngstcell, neq, ndim).
        self.assertEqual(len(svr.dsol.shape), 3)
        self.assertEqual(svr.dsol.shape[0], svr.dsoln.shape[0])
        self.assertEqual(svr.dsol.shape[1], svr.dsoln.shape[1])
        self.assertEqual(svr.dsol.shape[2], svr.dsoln.shape[2])
        self.assertEqual(svr.dsol.shape[0], nrow)
        self.assertEqual(svr.dsol.shape[1], svr.neq)
        self.assertEqual(svr.dsol.shape[2], svr.ndim)

    # Marching parameters shared by test_soln and test_dsoln.
    time = 0.0
    time_increment = 1.0
    nsteps = 10

    def _run_solver(self, time, time_increment, nsteps):
        """Build an initialised solver, zero soln/dsoln, march, return it."""
        svr = self._get_solver()
        svr.soln.fill(0.0)
        svr.dsoln.fill(0.0)
        svr.march(time, time_increment, nsteps)
        return svr

    def _accumulate_half_steps(self, rate, like):
        """
        Return ``rate*time_increment/2`` added ``2*nsteps`` times.

        The terms are added one at a time into a zero array shaped and typed
        like *like*, so the result can be compared exactly with the solver
        output. This assumes the testing solver adds the same term once per
        half step -- TODO confirm against TestingSolver.
        """
        from numpy import zeros
        ref = zeros(like.shape, dtype=like.dtype)
        for _ in range(self.nsteps*2):
            ref += rate*self.time_increment/2
        return ref

    def test_soln(self):
        svr = self._run_solver(self.time, self.time_increment, self.nsteps)
        ngstcell = svr.ngstcell
        soln = svr.soln[ngstcell:,0]
        expected = self._accumulate_half_steps(svr.clvol[ngstcell:], soln)
        # Exact equality is deliberate: the reference repeats the additions
        # term by term (presumably in the solver's own order).
        self.assertTrue((soln == expected).all())

    def test_dsoln(self):
        svr = self._run_solver(self.time, self.time_increment, self.nsteps)
        ngstcell = svr.ngstcell
        dsoln = svr.dsoln[ngstcell:,0,:]
        expected = self._accumulate_half_steps(svr.clcnd[ngstcell:], dsoln)
        # Exact equality is deliberate; see test_soln.
        self.assertTrue((dsoln == expected).all())
