Skip to content

Commit e8b19bc

Browse files
authored
CU-869a9w9v8: Allow a warning instead of a raised exception when doing supervised training (#121)
1 parent 075e0eb commit e8b19bc

File tree

1 file changed

+12 -5 lines changed

1 file changed

+12 -5 lines changed

medcat-v2/medcat/trainer.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
# NOTE: this should be used for changing the CDB, both for training and for
2626
# unlinking concept/names.
2727
class Trainer:
28+
strict_train: bool = False
2829

2930
def __init__(self, cdb: CDB, caller: Callable[[str], MutableDocument],
3031
pipeline: Pipeline):
@@ -453,11 +454,17 @@ def _train_supervised_for_project2(self,
453454
context = "[...]" + context
454455
if context_end < len(cur_text) - 1:
455456
context += "[...]"
456-
raise ValueError(
457-
f"Failed to identify '{cui}' ({ann['value']}) "
458-
f"([{ann['start']}:{ann['end']}]) "
459-
f"in '{context}' {mut_entity} within document "
460-
f"{doc['id']} | {doc['name']}") from ve
457+
msg_template = (
458+
"Failed to identify '%s' (%s) ([%d:%d]) "
459+
"in '%s' %s within document %s | %s, "
460+
"skipping training for this example")
461+
msg_context = (
462+
cui, ann['value'], ann['start'], ann['end'],
463+
context, mut_entity, doc['id'], doc['name'])
464+
if self.strict_train:
465+
raise ValueError(msg_template % msg_context) from ve
466+
else:
467+
logger.warning(msg_template, *msg_context, exc_info=ve)
461468
if train_from_false_positives:
462469
fps: list[MutableEntity] = get_false_positives(doc, mut_doc)
463470

0 commit comments

Comments
 (0)