Skip to content

Commit f5d81dd

Browse files
authored
add parameters --modelfile (#1286)
* Fix wasm-pack not found error * add parameters --modelfile
1 parent a6bab7f commit f5d81dd

File tree

4 files changed

+25
-8
lines changed

4 files changed

+25
-8
lines changed

visualdl/reader/graph_reader.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class GraphReader(object):
4141
"""Graph reader to read vdl graph files, support for frontend api in lib.py.
4242
"""
4343

44-
def __init__(self, logdir=''):
44+
def __init__(self, logdir='', model_name=''):
4545
"""Instance of GraphReader
4646
4747
Args:
@@ -52,6 +52,7 @@ def __init__(self, logdir=''):
5252
else:
5353
self.dir = logdir
5454

55+
self.model_name = model_name
5556
self.walks = {}
5657
self.displayname2runs = {}
5758
self.runs2displayname = {}
@@ -102,7 +103,10 @@ def graphs(self, update=False):
102103
]
103104
tags_temp.sort(reverse=True)
104105
if len(tags_temp) > 0:
105-
walks_temp.update({run: tags_temp[0]})
106+
if self.model_name:
107+
walks_temp.update({run: self.model_name})
108+
else:
109+
walks_temp.update({run: tags_temp[0]})
106110
self.walks = walks_temp
107111
return self.walks
108112

visualdl/server/api.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,13 @@ def try_call(function, *args, **kwargs):
7777

7878

7979
class Api(object):
80-
def __init__(self, logdir, model, cache_timeout):
80+
def __init__(self, logdir, model, modelfile, cache_timeout):
81+
self.model_name = ''
82+
if not logdir and modelfile:
83+
logdir = os.path.dirname(modelfile)
84+
self.model_name = os.path.basename(modelfile)
8185
self._reader = LogReader(logdir)
82-
self._graph_reader = GraphReader(logdir)
86+
self._graph_reader = GraphReader(logdir, self.model_name)
8387
self._graph_reader.set_displayname(self._reader)
8488
if model:
8589
if 'vdlgraph' in model:
@@ -415,7 +419,7 @@ def get_component_tabs(*apis, vdl_args, request_args):
415419
all_tabs = set()
416420
if vdl_args.component_tabs:
417421
return list(vdl_args.component_tabs)
418-
if vdl_args.logdir:
422+
if vdl_args.logdir or vdl_args.modelfile:
419423
for api in apis:
420424
all_tabs.update(api('component_tabs', request_args))
421425
all_tabs.add('static_graph')
@@ -427,8 +431,8 @@ def get_component_tabs(*apis, vdl_args, request_args):
427431
return list(all_tabs)
428432

429433

430-
def create_api_call(logdir, model, cache_timeout):
431-
api = Api(logdir, model, cache_timeout)
434+
def create_api_call(logdir, model, modelfile, cache_timeout):
435+
api = Api(logdir, model, modelfile, cache_timeout)
432436
routes = {
433437
'components': (api.components, []),
434438
'runs': (api.runs, []),

visualdl/server/app.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def get_locale():
8686
) # we add this to prevent SIGINT not work in multiprocess queue waiting
8787
babel = Babel(app, locale_selector=get_locale) # noqa:F841
8888
# Babel api from flask_babel v3.0.0
89-
api_call = create_api_call(args.logdir, args.model, args.cache_timeout)
89+
api_call = create_api_call(args.logdir, args.model, args.modelfile, args.cache_timeout)
9090
profiler_api_call = create_profiler_api_call(args.logdir)
9191
inference_api_call = create_model_convert_api_call()
9292
fastdeploy_api_call = create_fastdeploy_api_call()

visualdl/server/args.py

+9
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self, args):
4040
self.api_only = args.get('api_only', False)
4141
self.open_browser = args.get('open_browser', False)
4242
self.model = args.get('model', '')
43+
self.modelfile = args.get('modelfile', '')
4344
self.product = args.get('product', default_product)
4445
self.telemetry = args.get('telemetry', True)
4546
self.theme = args.get('theme', None)
@@ -123,6 +124,7 @@ def __init__(self, **kwargs):
123124
self.api_only = args.api_only
124125
self.open_browser = args.open_browser
125126
self.model = args.model
127+
self.modelfile = args.modelfile
126128
self.product = args.product
127129
self.telemetry = args.telemetry
128130
self.theme = args.theme
@@ -141,6 +143,13 @@ def parse_args():
141143
epilog="For more information: https://github.com/PaddlePaddle/VisualDL"
142144
)
143145

146+
parser.add_argument(
147+
"--modelfile",
148+
type=str,
149+
action="store",
150+
default="",
151+
help="json model file path")
152+
144153
parser.add_argument(
145154
"--logdir", action="store", nargs="+", help="log file directory")
146155

0 commit comments

Comments
 (0)