From 91a0f8652a92d13ecba821aab4208b1c049d0e1c Mon Sep 17 00:00:00 2001 From: jwatson Date: Thu, 7 Aug 2025 12:49:03 -0700 Subject: [PATCH 1/8] implement file download, take 2 with cursor --- .../rag/configuration/AppConfiguration.java | 12 +++ .../cai/rag/files/RagFileController.java | 23 +++++ .../cai/rag/files/RagFileService.java | 20 ++++- .../cai/rag/files/RagFileControllerTest.java | 84 +++++++++++++++++++ .../cai/rag/files/RagFileServiceTest.java | 57 ++++++++++++- .../ManageTab/UploadedFilesTable.tsx | 20 +++-- 6 files changed, 206 insertions(+), 10 deletions(-) diff --git a/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java b/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java index 032af662c..8800ea26d 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java +++ b/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java @@ -38,8 +38,11 @@ package com.cloudera.cai.rag.configuration; +import com.cloudera.cai.rag.files.FileSystemRagFileDownloader; import com.cloudera.cai.rag.files.FileSystemRagFileUploader; +import com.cloudera.cai.rag.files.RagFileDownloader; import com.cloudera.cai.rag.files.RagFileUploader; +import com.cloudera.cai.rag.files.S3RagFileDownloader; import com.cloudera.cai.rag.files.S3RagFileUploader; import com.cloudera.cai.util.reconcilers.ReconcilerConfig; import com.cloudera.cai.util.s3.AmazonS3Client; @@ -162,6 +165,15 @@ public RagFileUploader ragFileUploader(S3Config configuration) { return new S3RagFileUploader(s3Client, configuration.getBucketName()); } + @Bean + public RagFileDownloader ragFileDownloader(S3Config configuration) { + if (configuration.getBucketName().isEmpty()) { + return new FileSystemRagFileDownloader(); + } + AmazonS3Client s3Client = new AmazonS3Client(configuration); + return new S3RagFileDownloader(s3Client, configuration.getBucketName()); + } + public static String getLlmServiceUrl() { var llmServiceUrl = Optional.ofNullable(System.getenv("LLM_SERVICE_URL")) diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java index 8feb1b794..f5c667012 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java +++ b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java @@ -46,8 +46,12 @@ import java.util.List; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; @RestController @Slf4j @@ -86,4 +90,23 @@ public List getRagDocuments(@PathVariable Long dataSourceId) public void deleteRagFile(@PathVariable Long id, @PathVariable Long dataSourceId) { ragFileService.deleteRagFile(id, dataSourceId); } + + @GetMapping(value = "/dataSources/{dataSourceId}/files/{documentId}/download") + public ResponseEntity downloadRagDocument( + @PathVariable Long dataSourceId, @PathVariable String documentId) { + var downloaded = ragFileService.downloadDocument(dataSourceId, documentId); + String filename = downloaded.filename(); + + StreamingResponseBody body = + outputStream -> { + try (var in = downloaded.stream()) { + in.transferTo(outputStream); + } + }; + + return ResponseEntity.ok() + .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"") + .contentType(MediaType.APPLICATION_OCTET_STREAM) + .body(body); + } } diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java index 0b55faf9a..cc8a395eb 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java +++ b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java @@ -73,6 +73,7 @@ public class RagFileService { private final RagDataSourceRepository ragDataSourceRepository; private final RagFileDeleteReconciler ragFileDeleteReconciler; private final RagFileSummaryReconciler ragFileSummaryReconciler; + private final RagFileDownloader ragFileDownloader; @Autowired public RagFileService( @@ -83,7 +84,8 @@ public RagFileService( @Qualifier("s3BucketPrefix") String s3PathPrefix, RagDataSourceRepository ragDataSourceRepository, RagFileDeleteReconciler ragFileDeleteReconciler, - RagFileSummaryReconciler ragFileSummaryReconciler) { + RagFileSummaryReconciler ragFileSummaryReconciler, + RagFileDownloader ragFileDownloader) { this.idGenerator = idGenerator; this.ragFileRepository = ragFileRepository; this.ragFileUploader = ragFileUploader; @@ -92,6 +94,7 @@ public RagFileService( this.ragDataSourceRepository = ragDataSourceRepository; this.ragFileDeleteReconciler = ragFileDeleteReconciler; this.ragFileSummaryReconciler = ragFileSummaryReconciler; + this.ragFileDownloader = ragFileDownloader; } public List saveRagFile( @@ -214,6 +217,18 @@ public List getRagDocuments(Long dataSourceId) { return ragFileRepository.getRagDocuments(dataSourceId); } + public record DownloadedDocument(String filename, String s3Path, java.io.InputStream stream) {} + + public DownloadedDocument downloadDocument(Long dataSourceId, String documentId) { + var document = ragFileRepository.findDocumentByDocumentId(documentId); + if (!document.dataSourceId().equals(dataSourceId)) { + throw new NotFound( + "Document with id " + documentId + " not found for dataSourceId: " + dataSourceId); + } + var inputStream = ragFileDownloader.openStream(document.s3Path()); + return new DownloadedDocument(document.filename(), document.s3Path(), inputStream); + } + public record MultipartUploadableFile(MultipartFile file) implements UploadableFile { @Override @@ -262,6 +277,7 @@ public static RagFileService createNull(String... dummyIds) { "prefix", RagDataSourceRepository.createNull(), RagFileDeleteReconciler.createNull(), - RagFileSummaryReconciler.createNull()); + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull()); } } diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java index a333f37b4..b6cf77bae 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java @@ -47,13 +47,19 @@ import com.cloudera.cai.rag.Types; import com.cloudera.cai.rag.Types.RagDocument; import com.cloudera.cai.rag.datasources.RagDataSourceRepository; +import com.cloudera.cai.util.IdGenerator; import com.cloudera.cai.util.exceptions.BadRequest; +import com.cloudera.cai.util.exceptions.NotFound; +import java.io.ByteArrayOutputStream; import java.util.List; import java.util.Random; import java.util.UUID; import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseEntity; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockMultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; class RagFileControllerTest { @@ -164,4 +170,82 @@ void delete() { ragFileController.deleteRagFile(id, dataSourceId); assertThat(ragFileController.getRagDocuments(dataSourceId)).extracting("id").doesNotContain(id); } + + @Test + void download_success_streamsAttachment() throws Exception { + var dsRepo = RagDataSourceRepository.createNull(); + var repo = RagFileRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = UUID.randomUUID().toString(); + String originalFilename = "mydoc.pdf"; + byte[] bytes = "hello world".getBytes(); + String prefix = "prefix"; + String s3Path = prefix + "/" + dataSourceId + "/" + documentId; + + RagFileService ragFileService = + new RagFileService( + IdGenerator.createNull(documentId), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + prefix, + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull(java.util.Map.of(s3Path, bytes))); + RagFileController controller = new RagFileController(ragFileService); + + // First upload to create metadata with known documentId + var request = new MockHttpServletRequest(); + TestData.addUserToRequest(request); + controller.uploadRagDocument( + new MockMultipartFile("file", originalFilename, "application/pdf", bytes), + dataSourceId, + request); + + ResponseEntity response = + controller.downloadRagDocument(dataSourceId, documentId); + assertThat(response.getStatusCode().is2xxSuccessful()).isTrue(); + assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_DISPOSITION)) + .isEqualTo("attachment; filename=\"" + originalFilename + "\""); + StreamingResponseBody body = response.getBody(); + assertThat(body).isNotNull(); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + body.writeTo(out); + assertThat(out.toByteArray()).isEqualTo(bytes); + } + + @Test + void download_wrongDataSource_throwsNotFound() { + var dsRepo = RagDataSourceRepository.createNull(); + var repo = RagFileRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = UUID.randomUUID().toString(); + String prefix = "prefix"; + String s3Path = prefix + "/" + dataSourceId + "/" + documentId; + + RagFileService ragFileService = + new RagFileService( + IdGenerator.createNull(documentId), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + prefix, + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull(java.util.Map.of(s3Path, "x".getBytes()))); + RagFileController controller = new RagFileController(ragFileService); + + // Create the metadata + var request = new MockHttpServletRequest(); + TestData.addUserToRequest(request); + controller.uploadRagDocument( + new MockMultipartFile("file", "f.txt", "text/plain", new byte[] {1, 2, 3}), + dataSourceId, + request); + + assertThatThrownBy(() -> controller.downloadRagDocument(Long.MAX_VALUE, documentId)) + .isInstanceOf(NotFound.class); + } } diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java index ea3124a42..ee8b15a6d 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java @@ -224,7 +224,8 @@ private RagFileService createRagFileService( prefix, dataSourceRepository, RagFileDeleteReconciler.createNull(), - RagFileSummaryReconciler.createNull()); + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull()); } private long newDataSourceId() { @@ -311,4 +312,58 @@ private MockMultipartFile createZipFile(String[][] fileEntries, String contentTy } return new MockMultipartFile("test.zip", "test.zip", contentType, outputStream.toByteArray()); } + + @Test + void downloadDocument_success() throws Exception { + var repo = RagFileRepository.createNull(); + var dsRepo = RagDataSourceRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = java.util.UUID.randomUUID().toString(); + // Insert a document + Long id = TestData.createTestDocument(dataSourceId, documentId, repo); + // Build service with same repo and a downloader that returns bytes + byte[] content = "hello world".getBytes(); + RagFileService service = + new RagFileService( + IdGenerator.createNull(), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + "prefix", + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull(java.util.Map.of("doesn't matter", content))); + + var downloaded = service.downloadDocument(dataSourceId, documentId); + assertThat(downloaded.filename()).isNotNull(); + try (var in = downloaded.stream()) { + byte[] read = in.readAllBytes(); + assertThat(read).isNotEmpty(); + } + } + + @Test + void downloadDocument_wrongDataSourceId() { + var repo = RagFileRepository.createNull(); + var dsRepo = RagDataSourceRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = java.util.UUID.randomUUID().toString(); + TestData.createTestDocument(dataSourceId, documentId, repo); + + RagFileService service = + new RagFileService( + IdGenerator.createNull(), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + "prefix", + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull()); + + assertThatThrownBy(() -> service.downloadDocument(Long.MAX_VALUE, documentId)) + .isInstanceOf(NotFound.class); + } } diff --git a/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx b/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx index a6dae7795..67503f046 100644 --- a/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx +++ b/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx @@ -46,6 +46,7 @@ import { Typography, } from "antd"; import Icon, { DeleteOutlined } from "@ant-design/icons"; +import { DownloadOutlined } from "@ant-design/icons"; import { RagDocumentResponseType, useDeleteDocumentMutation, @@ -63,6 +64,7 @@ import { cdlWhite } from "src/cuix/variables.ts"; import ReadyColumn from "pages/DataSources/ManageTab/ReadyColumn.tsx"; import SummaryColumn from "pages/DataSources/ManageTab/SummaryColumn.tsx"; import { ColumnsType } from "antd/es/table"; +import { paths, ragPath } from "src/api/utils.ts"; const columns = ( dataSourceId: string, @@ -152,14 +154,18 @@ const columns = ( { title: "Actions", render: (_, record) => { + const url = `${ragPath}/${paths.dataSources}/${record.dataSourceId.toString()}/${paths.files}/${record.documentId}/download`; return ( -