Skip to content

Commit 06a4b55

Browse files
committed
Adding methods for obtaining nodes from Demographics objects by name. Final part of (Fix #690).
1 parent 06458a7 commit 06a4b55

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed

emod_api/demographics/Demographics.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,14 @@ def _all_nodes(self) -> List[Node]:
460460
default_node = [self.default_node] if hasattr(self, 'default_node') else []
461461
return self.nodes + default_node
462462

463+
@property
464+
def _all_node_names(self) -> List[int]:
465+
return [node.name for node in self._all_nodes]
466+
467+
@property
468+
def _all_nodes_by_name(self) -> Dict[int, Node]:
469+
return {node.name: node for node in self._all_nodes}
470+
463471
@property
464472
def _all_node_ids(self) -> List[int]:
465473
return [node.id for node in self._all_nodes]
@@ -470,7 +478,7 @@ def _all_nodes_by_id(self) -> Dict[int, Node]:
470478

471479
def get_node_by_id(self, node_id: int) -> Node:
472480
"""
473-
Returns the Node objects requested by their node id.
481+
Returns the Node object requested by its node id.
474482
475483
Args:
476484
node_id: a node_id to use in retrieving the requested Node object. None or 0 for 'the default node'.
@@ -505,6 +513,44 @@ def get_nodes_by_id(self, node_ids: List[int]) -> Dict[int, Node]:
505513
requested_nodes = {node_id: node for node_id, node in self._all_nodes_by_id.items() if node_id in node_ids}
506514
return requested_nodes
507515

516+
def get_node_by_name(self, node_name: str) -> Node:
517+
"""
518+
Returns the Node object requested by its node name.
519+
520+
Args:
521+
node_name: a node_name to use in retrieving the requested Node object. None for 'the default node'.
522+
523+
Returns:
524+
a Node object
525+
"""
526+
return list(self.get_nodes_by_name(node_names=[node_name]).values())[0]
527+
528+
def get_nodes_by_name(self, node_names: List[str]) -> Dict[str, Node]:
529+
"""
530+
Returns the Node objects requested by their node name.
531+
532+
Args:
533+
node_names: a list of node names to use in retrieving Node objects. None for 'the default node'.
534+
535+
Returns:
536+
a dict with name: node entries
537+
"""
538+
# replace a None name (default node) request with the default node's name
539+
if node_names is None:
540+
node_names = [self.default_node.name]
541+
if None in node_names:
542+
node_names.remove(None)
543+
node_names.append(self.default_node.name)
544+
545+
missing_node_names = [node_name for node_name in node_names if node_name not in self._all_node_names]
546+
if len(missing_node_names) > 0:
547+
msg = ', '.join([str(node_name) for node_name in missing_node_names])
548+
raise self.UnknownNodeException(f"The following node name(s) were requested but do not exist in this demographics "
549+
f"object:\n{msg}")
550+
requested_nodes = {node_name: node for node_name, node in self._all_nodes_by_name.items()
551+
if node_name in node_names}
552+
return requested_nodes
553+
508554
def SetMigrationPattern(self, pattern: str = "rwd"):
509555
"""
510556
Set migration pattern. Migration is enabled implicitly.

tests/test_demog.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,57 @@ def setUp(self) -> None:
2424
print(f"\n{self._testMethodName} started...")
2525
self.out_folder = manifest.demo_folder
2626

27+
def test_get_node_by_name(self):
28+
from emod_api.demographics.Demographics import DemographicsBase
29+
30+
mars = Node.Node(lat=0, lon=0, pop=100, name='Mars')
31+
venus = Node.Node(lat=0, lon=0, pop=100, name='Venus')
32+
earth = Node.Node(lat=0, lon=0, pop=100, name='Earth')
33+
nodes = [mars, venus]
34+
demographics = Demographics.Demographics(nodes=nodes, default_node=earth)
35+
36+
# a non default node
37+
node = demographics.get_node_by_name(node_name=mars.name)
38+
self.assertEqual(node, nodes[0])
39+
40+
# the default node, checked explicitly then implicitly
41+
node = demographics.get_node_by_name(node_name=earth.name)
42+
self.assertEqual(node, earth)
43+
44+
node = demographics.get_node_by_name(node_name=None)
45+
self.assertEqual(node, earth)
46+
47+
# a node name that does not exist (yet, at least!)
48+
self.assertRaises(DemographicsBase.UnknownNodeException, demographics.get_node_by_name, node_name='Planet X')
49+
50+
def test_get_nodes_by_name(self):
51+
from emod_api.demographics.Demographics import DemographicsBase
52+
53+
mars = Node.Node(lat=0, lon=0, pop=100, name='Mars')
54+
venus = Node.Node(lat=0, lon=0, pop=100, name='Venus')
55+
earth = Node.Node(lat=0, lon=0, pop=100, name='Earth')
56+
nodes = [mars, venus]
57+
demographics = Demographics.Demographics(nodes=nodes, default_node=earth)
58+
59+
# just getting some nodes, also checking explicit default node request, too.
60+
nodes = demographics.get_nodes_by_name(node_names=['Mars', 'Earth'])
61+
expected = {earth.name: earth, mars.name: mars}
62+
self.assertEqual(nodes, expected)
63+
64+
# verify that a node name of None will yield the default node
65+
nodes = demographics.get_nodes_by_name(node_names=['Mars', None])
66+
expected = {earth.name: earth, mars.name: mars}
67+
self.assertEqual(nodes, expected)
68+
69+
nodes = demographics.get_nodes_by_name(node_names=None)
70+
expected = {earth.name: earth}
71+
self.assertEqual(nodes, expected)
72+
73+
# a node name that does not exist (yet, at least!)
74+
self.assertRaises(DemographicsBase.UnknownNodeException, demographics.get_nodes_by_name,
75+
node_names=['Planet X', earth.name])
76+
77+
2778
def test_demo_basic_node(self):
2879
out_filename = os.path.join(self.out_folder, "demographics_basic_node.json")
2980
demog = Demographics.from_template_node()
@@ -2609,7 +2660,6 @@ def test_SetEquilibriumVitalDynamicsFromWorldBank_EH_06(self):
26092660
self.from_csv_with_country("Least developed countries: UN classification", 1980)
26102661
def test_SetEquilibriumVitalDynamicsFromWorldBank_EH_07(self):
26112662
self.from_csv_with_country("Turks and Caicos Islands", 1980)
2612-
26132663
#endregion
26142664
if __name__ == '__main__':
26152665
unittest.main()

0 commit comments

Comments
 (0)