7
7
from primehub .utils .optionals import file_flag , toggle_flag
8
8
from primehub .utils .validator import validate_connection
9
9
10
+ group_basic_info = """
11
+ fragment GroupBasicInfo on Group {
12
+ id
13
+ displayName
14
+ name
15
+ admins
16
+ quotaCpu
17
+ quotaGpu
18
+ quotaMemory
19
+ projectQuotaCpu
20
+ projectQuotaGpu
21
+ projectQuotaMemory
22
+ sharedVolumeCapacity
23
+ }
24
+ """
25
+
10
26
11
27
def invalid_config (message : str ):
12
28
example = """
@@ -844,7 +860,6 @@ def create(self, config: dict):
844
860
id
845
861
displayName
846
862
name
847
- admins
848
863
quotaCpu
849
864
quotaGpu
850
865
quotaMemory
@@ -853,9 +868,14 @@ def create(self, config: dict):
853
868
projectQuotaMemory
854
869
sharedVolumeCapacity
855
870
}
856
- """
871
+ """ + group_basic_info
857
872
858
873
apply_auto_fill (config )
874
+
875
+ # cannot specify admins when creating
876
+ if config .get ('admins' ):
877
+ config ['admins' ] = ''
878
+
859
879
results = self .request ({'data' : validate (config )}, query )
860
880
861
881
if 'data' not in results :
@@ -893,20 +913,7 @@ def list(self, **kwargs) -> Iterator:
893
913
}
894
914
}
895
915
}
896
- fragment GroupBasicInfo on Group {
897
- id
898
- displayName
899
- name
900
- admins
901
- quotaCpu
902
- quotaGpu
903
- quotaMemory
904
- projectQuotaCpu
905
- projectQuotaGpu
906
- projectQuotaMemory
907
- sharedVolumeCapacity
908
- }
909
- """
916
+ """ + group_basic_info
910
917
911
918
variables : dict = {'orderBy' : {}, 'where' : {}}
912
919
page = kwargs .get ('page' , 0 )
@@ -1012,10 +1019,13 @@ def get(self, id: str) -> dict:
1012
1019
1013
1020
if 'data' not in results :
1014
1021
return results
1022
+ group = results ['data' ]['group' ]
1023
+ if not group :
1024
+ return group
1015
1025
1016
- results [ 'data' ][ ' group' ] ['volumes' ] = results [ 'data' ][ ' group' ] .pop ('datasets' , '[]' )
1017
-
1018
- return results [ 'data' ][ ' group' ]
1026
+ group ['volumes' ] = group .pop ('datasets' , '[]' )
1027
+ self . _output_format_admins ( id , group )
1028
+ return group
1019
1029
1020
1030
def _everyone_group_id (self ) -> dict :
1021
1031
query = """
@@ -1070,31 +1080,40 @@ def update(self, id: str, config: dict):
1070
1080
"""
1071
1081
1072
1082
query = """
1073
- mutation UpdateGroup($data: GroupUpdateInput!, $where: GroupWhereUniqueInput!) {
1083
+ mutation UpdateGroup(
1084
+ $data: GroupUpdateInput!,
1085
+ $where: GroupWhereUniqueInput!
1086
+ ) {
1074
1087
updateGroup(data: $data, where: $where) {
1075
1088
...GroupBasicInfo
1076
1089
}
1077
1090
}
1078
- fragment GroupBasicInfo on Group {
1079
- id
1080
- displayName
1081
- name
1082
- admins
1083
- quotaCpu
1084
- quotaGpu
1085
- quotaMemory
1086
- projectQuotaCpu
1087
- projectQuotaGpu
1088
- projectQuotaMemory
1089
- sharedVolumeCapacity
1090
- }
1091
- """
1091
+ """ + group_basic_info
1092
+
1093
+ if config .get ('admins' ):
1094
+ config ['admins' ] = self ._transform_admins (id , config .get ('admins' , []))
1095
+
1092
1096
variables = {'where' : {'id' : id }, 'data' : validate (config , True )}
1093
1097
results = self .request (variables , query )
1094
1098
1095
1099
if 'data' not in results :
1096
1100
return results
1097
- return results ['data' ]['updateGroup' ]
1101
+
1102
+ updated_query = """
1103
+ query Group(
1104
+ $where: GroupWhereUniqueInput!
1105
+ ) {
1106
+ group(where: $where) {
1107
+ ...GroupBasicInfo
1108
+ }
1109
+ }
1110
+ """ + group_basic_info
1111
+ updated_results = self .request ({'where' : {'id' : id }}, updated_query )
1112
+ if 'data' not in updated_results :
1113
+ return updated_results
1114
+ updated_group = updated_results ['data' ]['group' ]
1115
+ self ._output_format_admins (id , updated_group )
1116
+ return updated_group
1098
1117
1099
1118
@cmd (name = 'delete' , description = 'Delete the group by id' , return_required = True )
1100
1119
def delete (self , id : str ) -> dict :
@@ -1121,5 +1140,42 @@ def delete(self, id: str) -> dict:
1121
1140
return results
1122
1141
return results ['data' ]['deleteGroup' ]
1123
1142
1143
+ def _transform_admins (self , id : str , user_ids : List [str ]):
1144
+ if len (user_ids ) == 0 :
1145
+ return ''
1146
+
1147
+ member_dict = {}
1148
+ users = self .primehub .admin .admin_groups .list_users (id )
1149
+ for user in users :
1150
+ user_id = user ['id' ]
1151
+ username = user ['username' ]
1152
+ member_dict [user_id ] = username
1153
+
1154
+ admin_usernames = []
1155
+ invalid_user_ids = []
1156
+ for user_id in user_ids :
1157
+ if user_id in member_dict :
1158
+ admin_usernames .append (member_dict [user_id ])
1159
+ else :
1160
+ invalid_user_ids .append (user_id )
1161
+
1162
+ if len (invalid_user_ids ) > 0 :
1163
+ _invalid_ids = ', ' .join (invalid_user_ids )
1164
+ msg = f'admins contain invalid user ids: { _invalid_ids } '
1165
+ raise PrimeHubException (msg )
1166
+ return ',' .join (admin_usernames )
1167
+
1168
+ def _output_format_admins (self , id : str , group : dict ):
1169
+ admin_users = []
1170
+ admin_usernames = group .get ('admins' , '' ).split (',' )
1171
+ users = self .primehub .admin .admin_groups .list_users (id )
1172
+ for user in users :
1173
+ if user ['username' ] in admin_usernames :
1174
+ admin_users .append (dict (
1175
+ id = user ['id' ],
1176
+ username = user ['username' ]
1177
+ ))
1178
+ group ['admins' ] = admin_users
1179
+
1124
1180
def help_description (self ):
1125
1181
return "Manage groups"
0 commit comments