Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 48 additions & 29 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
DB_VERSION_LEGACY_KEY,
JSON_FILENAME,
SQL_FILENAME,
TAG_CHILDREN_QUERY,
)
from tagstudio.core.library.alchemy.db import make_tables
from tagstudio.core.library.alchemy.enums import (
Expand Down Expand Up @@ -555,6 +554,20 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
# Convert file extension list to ts_ignore file, if a .ts_ignore file does not exist
self.migrate_sql_to_ts_ignore(library_dir)

session.execute(
text("CREATE INDEX IF NOT EXISTS idx_tags_name_shorthand ON tags (name, shorthand)")
)
session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_tag_parents_child_id ON tag_parents (child_id)"
)
)
session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_tag_entries_entry_id ON tag_entries (entry_id)"
)
)

# Update DB_VERSION
if loaded_db_version < DB_VERSION:
self.set_version(DB_VERSION_CURRENT_KEY, DB_VERSION)
Expand Down Expand Up @@ -1054,55 +1067,61 @@ def search_library(

return res

def search_tags(self, name: str | None, limit: int = 100) -> list[set[Tag]]:
def search_tags(self, name: str | None, limit: int = 100) -> tuple[list[Tag], list[Tag]]:
"""Return a list of Tag records matching the query."""
name = name or ""
name = name.lower()

def sort_key(text: str):
return (not text.startswith(name), len(text), text)

with Session(self.engine) as session:
query = select(Tag).outerjoin(TagAlias).order_by(func.lower(Tag.name))
query = query.options(
selectinload(Tag.parent_tags),
selectinload(Tag.aliases),
)
if limit > 0:
query = query.limit(limit)
query = select(Tag.id, Tag.name)

if limit > 0 and not name:
query = query.limit(limit).order_by(func.lower(Tag.name))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes the sorting to happen after truncating the results, which differs from the behaviour when searching, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was not sorting by len(text) before truncating. The previous implementation only sorted priority results by len, so I've updated sort_key to do that, which will make this order_by statement and sort_key produce the same results when no query is provided.


if name:
query = query.where(
or_(
Tag.name.icontains(name),
Tag.shorthand.icontains(name),
TagAlias.name.icontains(name),
)
)

direct_tags = set(session.scalars(query))
ancestor_tag_ids: list[Tag] = []
for tag in direct_tags:
ancestor_tag_ids.extend(
list(session.scalars(TAG_CHILDREN_QUERY, {"tag_id": tag.id}))
)

ancestor_tags = session.scalars(
select(Tag)
.where(Tag.id.in_(ancestor_tag_ids))
.options(selectinload(Tag.parent_tags), selectinload(Tag.aliases))
)
tags = list(session.execute(query))

res = [
direct_tags,
{at for at in ancestor_tags if at not in direct_tags},
]
if name:
query = select(TagAlias.tag_id, TagAlias.name).where(TagAlias.name.icontains(name))
tags.extend(session.execute(query))

tags.sort(key=lambda t: sort_key(t[1]))
seen_ids = set()
tag_ids = []
for row in tags:
id = row[0]
if id in seen_ids:
continue
tag_ids.append(id)
seen_ids.add(id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this be written as the following?

tags = dict(tags)
tag_ids = sorted(tags.keys(), key=lambda t: sort_key(tags[t]))
del tags # not sure if this makes a difference, but `tags` could become quite large and triggering gc on it sooner can't hurt

This is both simpler code-wise and should be faster, since it only sorts the deduplicated list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so it will use the order from Tag.name or TagAlias.name depending on which comes first for each tag.

Copy link
Collaborator

@Computerdores Computerdores Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

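In that case, the following should work, since dict.keys() maintains insertion order and dict deduplicates by key (a repeated key keeps the position of its first insertion, so each tag stays at its best-ranked match):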

tags.sort(key=lambda t: sort_key(t[1]))
tag_ids = dict(tags).keys()  # get the deduplicated list of ids

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than that, I think this is good to merge.


logger.info(
"searching tags",
search=name,
limit=limit,
statement=str(query),
results=len(res),
results=len(tag_ids),
)

session.expunge_all()
if limit <= 0:
limit = len(tag_ids)
tag_ids = tag_ids[:limit]

return res
hierarchy = self.get_tag_hierarchy(tag_ids)
direct_tags = [hierarchy.pop(id) for id in tag_ids]
ancestor_tags = list(hierarchy.values())
ancestor_tags.sort(key=lambda t: sort_key(t.name))
return direct_tags, ancestor_tags

def update_entry_path(self, entry_id: int | Entry, path: Path) -> bool:
"""Set the path field of an entry.
Expand Down
27 changes: 5 additions & 22 deletions src/tagstudio/qt/mixed/tag_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,32 +218,15 @@ def update_tags(self, query: str | None = None):
self.scroll_layout.takeAt(self.scroll_layout.count() - 1).widget().deleteLater()
self.create_button_in_layout = False

# Get results for the search query
query_lower = "" if not query else query.lower()
# Only use the tag limit if it's an actual number (aka not "All Tags")
tag_limit = TagSearchPanel.tag_limit if isinstance(TagSearchPanel.tag_limit, int) else -1
tag_results: list[set[Tag]] = self.lib.search_tags(name=query, limit=tag_limit)
if self.exclude:
tag_results[0] = {t for t in tag_results[0] if t.id not in self.exclude}
tag_results[1] = {t for t in tag_results[1] if t.id not in self.exclude}

# Sort and prioritize the results
results_0 = list(tag_results[0])
results_0.sort(key=lambda tag: tag.name.lower())
results_1 = list(tag_results[1])
results_1.sort(key=lambda tag: tag.name.lower())
raw_results = list(results_0 + results_1)
priority_results: set[Tag] = set()
all_results: list[Tag] = []
direct_tags, ancestor_tags = self.lib.search_tags(name=query, limit=tag_limit)

if query and query.strip():
for tag in raw_results:
if tag.name.lower().startswith(query_lower):
priority_results.add(tag)
all_results = [t for t in direct_tags if t.id not in self.exclude]
for tag in ancestor_tags:
if tag.id not in self.exclude:
all_results.append(tag)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code previously handled self.exclude being None, is there a reason you removed that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its type is list[int] and I couldn't find any code that could cause it to be None.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good reason ^^, I missed that

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
all_results = [t for t in direct_tags if t.id not in self.exclude]
for tag in ancestor_tags:
if tag.id not in self.exclude:
all_results.append(tag)
all_results = [t for t in direct_tags if t.id not in self.exclude]
all_results += [t for t in ancestor_tags if t.id not in self.exclude]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ended up doing this to avoid creating extra lists.
all_results.extend(t for t in ancestor_tags if t.id not in self.exclude)


all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
r for r in raw_results if r not in priority_results
]
if tag_limit > 0:
all_results = all_results[:tag_limit]

Expand Down
8 changes: 4 additions & 4 deletions tests/test_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def test_library_search(library: Library, entry_full: Entry):
def test_tag_search(library: Library):
tag = library.tags[0]

assert library.search_tags(tag.name.lower())
assert library.search_tags(tag.name.upper())
assert library.search_tags(tag.name[2:-2])
assert library.search_tags(tag.name * 2) == [set(), set()]
assert library.search_tags(tag.name.lower())[0]
assert library.search_tags(tag.name.upper())[0]
assert library.search_tags(tag.name[2:-2])[0]
assert library.search_tags(tag.name * 2) == ([], [])


def test_get_entry(library: Library, entry_min: Entry):
Expand Down