diff --git a/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java b/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java index 032af662c..8800ea26d 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java +++ b/backend/src/main/java/com/cloudera/cai/rag/configuration/AppConfiguration.java @@ -38,8 +38,11 @@ package com.cloudera.cai.rag.configuration; +import com.cloudera.cai.rag.files.FileSystemRagFileDownloader; import com.cloudera.cai.rag.files.FileSystemRagFileUploader; +import com.cloudera.cai.rag.files.RagFileDownloader; import com.cloudera.cai.rag.files.RagFileUploader; +import com.cloudera.cai.rag.files.S3RagFileDownloader; import com.cloudera.cai.rag.files.S3RagFileUploader; import com.cloudera.cai.util.reconcilers.ReconcilerConfig; import com.cloudera.cai.util.s3.AmazonS3Client; @@ -162,6 +165,15 @@ public RagFileUploader ragFileUploader(S3Config configuration) { return new S3RagFileUploader(s3Client, configuration.getBucketName()); } + @Bean + public RagFileDownloader ragFileDownloader(S3Config configuration) { + if (configuration.getBucketName().isEmpty()) { + return new FileSystemRagFileDownloader(); + } + AmazonS3Client s3Client = new AmazonS3Client(configuration); + return new S3RagFileDownloader(s3Client, configuration.getBucketName()); + } + public static String getLlmServiceUrl() { var llmServiceUrl = Optional.ofNullable(System.getenv("LLM_SERVICE_URL")) diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/FileSystemRagFileDownloader.java b/backend/src/main/java/com/cloudera/cai/rag/files/FileSystemRagFileDownloader.java new file mode 100644 index 000000000..ff106738b --- /dev/null +++ b/backend/src/main/java/com/cloudera/cai/rag/files/FileSystemRagFileDownloader.java @@ -0,0 +1,32 @@ +package com.cloudera.cai.rag.files; + +import com.cloudera.cai.util.exceptions.NotFound; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class FileSystemRagFileDownloader implements RagFileDownloader { + private static final String FILE_STORAGE_ROOT = fileStoragePath(); + + private static String fileStoragePath() { + var fileStoragePath = System.getenv("RAG_DATABASES_DIR") + "/file_storage"; + log.info("configured with fileStoragePath = {}", fileStoragePath); + return fileStoragePath; + } + + @Override + public InputStream openStream(String s3Path) throws NotFound { + try { + Path filePath = Path.of(FILE_STORAGE_ROOT, s3Path); + if (!Files.exists(filePath)) { + throw new NotFound("no document found with storage path: " + s3Path); + } + return Files.newInputStream(filePath); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java index 8feb1b794..57b742fdd 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java +++ b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileController.java @@ -46,8 +46,12 @@ import java.util.List; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; @RestController @Slf4j @@ -82,8 +86,27 @@ public List getRagDocuments(@PathVariable Long dataSourceId) return ragFileService.getRagDocuments(dataSourceId); } - @DeleteMapping(value = "/dataSources/{dataSourceId}/files/{id}") - public void deleteRagFile(@PathVariable Long id, @PathVariable Long dataSourceId) { - ragFileService.deleteRagFile(id, dataSourceId); + @DeleteMapping(value = "/dataSources/{dataSourceId}/files/{documentId}") + public void deleteRagFile(@PathVariable Long dataSourceId, @PathVariable String documentId) { + ragFileService.deleteRagFileByDocumentId(documentId, dataSourceId); + } + + @GetMapping(value = "/dataSources/{dataSourceId}/files/{documentId}/download") + public ResponseEntity downloadRagDocument( + @PathVariable Long dataSourceId, @PathVariable String documentId) { + var downloaded = ragFileService.downloadDocumentByDocumentId(dataSourceId, documentId); + String filename = downloaded.filename(); + + StreamingResponseBody body = + outputStream -> { + try (var in = downloaded.stream()) { + in.transferTo(outputStream); + } + }; + + return ResponseEntity.ok() + .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"") + .contentType(MediaType.APPLICATION_OCTET_STREAM) + .body(body); } } diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileDownloader.java b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileDownloader.java new file mode 100644 index 000000000..ade63a9dd --- /dev/null +++ b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileDownloader.java @@ -0,0 +1,35 @@ +package com.cloudera.cai.rag.files; + +import com.cloudera.cai.util.exceptions.NotFound; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.HashMap; +import java.util.Map; + +/** Downloader abstraction to open a streaming InputStream for a stored document. */ +public interface RagFileDownloader { + /** + * Opens a streaming InputStream for the provided storage path. The caller is responsible for + * closing the returned stream. + */ + InputStream openStream(String s3Path) throws NotFound; + + /** + * Test double that serves InputStreams from an in-memory byte[] map. Useful for unit tests + * without using a mocking framework. + */ + static RagFileDownloader createNull() { + return createNull(Map.of()); + } + + static RagFileDownloader createNull(Map pathToBytes) { + Map backing = new HashMap<>(pathToBytes); + return s3Path -> { + byte[] data = backing.get(s3Path); + if (data == null) { + throw new NotFound("no document found with storage path: " + s3Path); + } + return new ByteArrayInputStream(data); + }; + } +} diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java index 0b55faf9a..9a1ed8de4 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java +++ b/backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java @@ -73,6 +73,7 @@ public class RagFileService { private final RagDataSourceRepository ragDataSourceRepository; private final RagFileDeleteReconciler ragFileDeleteReconciler; private final RagFileSummaryReconciler ragFileSummaryReconciler; + private final RagFileDownloader ragFileDownloader; @Autowired public RagFileService( @@ -83,7 +84,8 @@ public RagFileService( @Qualifier("s3BucketPrefix") String s3PathPrefix, RagDataSourceRepository ragDataSourceRepository, RagFileDeleteReconciler ragFileDeleteReconciler, - RagFileSummaryReconciler ragFileSummaryReconciler) { + RagFileSummaryReconciler ragFileSummaryReconciler, + RagFileDownloader ragFileDownloader) { this.idGenerator = idGenerator; this.ragFileRepository = ragFileRepository; this.ragFileUploader = ragFileUploader; @@ -92,6 +94,7 @@ public RagFileService( this.ragDataSourceRepository = ragDataSourceRepository; this.ragFileDeleteReconciler = ragFileDeleteReconciler; this.ragFileSummaryReconciler = ragFileSummaryReconciler; + this.ragFileDownloader = ragFileDownloader; } public List saveRagFile( @@ -201,12 +204,16 @@ private static String validateFilename(String originalFilename) { return originalFilename; } - public void deleteRagFile(Long id, Long dataSourceId) { - var document = ragFileRepository.getRagDocumentById(id); + public void deleteRagFileByDocumentId(String documentId, Long dataSourceId) { + var document = ragFileRepository.findDocumentByDocumentId(documentId); if (!document.dataSourceId().equals(dataSourceId)) { - throw new NotFound("Document with id " + id + " not found for dataSourceId: " + dataSourceId); + throw new NotFound( + "Document with documentId " + + documentId + + " not found for dataSourceId: " + + dataSourceId); } - ragFileRepository.deleteById(id); + ragFileRepository.deleteById(document.id()); ragFileDeleteReconciler.submit(document); } @@ -214,6 +221,30 @@ public List getRagDocuments(Long dataSourceId) { return ragFileRepository.getRagDocuments(dataSourceId); } + public record DownloadedDocument(String filename, String s3Path, java.io.InputStream stream) {} + + public DownloadedDocument downloadDocument(Long dataSourceId, Long id) { + var document = ragFileRepository.getRagDocumentById(id); + if (!document.dataSourceId().equals(dataSourceId)) { + throw new NotFound("Document with id " + id + " not found for dataSourceId: " + dataSourceId); + } + var inputStream = ragFileDownloader.openStream(document.s3Path()); + return new DownloadedDocument(document.filename(), document.s3Path(), inputStream); + } + + public DownloadedDocument downloadDocumentByDocumentId(Long dataSourceId, String documentId) { + var document = ragFileRepository.findDocumentByDocumentId(documentId); + if (!document.dataSourceId().equals(dataSourceId)) { + throw new NotFound( + "Document with documentId " + + documentId + + " not found for dataSourceId: " + + dataSourceId); + } + var inputStream = ragFileDownloader.openStream(document.s3Path()); + return new DownloadedDocument(document.filename(), document.s3Path(), inputStream); + } + public record MultipartUploadableFile(MultipartFile file) implements UploadableFile { @Override @@ -262,6 +293,7 @@ public static RagFileService createNull(String... dummyIds) { "prefix", RagDataSourceRepository.createNull(), RagFileDeleteReconciler.createNull(), - RagFileSummaryReconciler.createNull()); + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull()); } } diff --git a/backend/src/main/java/com/cloudera/cai/rag/files/S3RagFileDownloader.java b/backend/src/main/java/com/cloudera/cai/rag/files/S3RagFileDownloader.java new file mode 100644 index 000000000..673e1eb02 --- /dev/null +++ b/backend/src/main/java/com/cloudera/cai/rag/files/S3RagFileDownloader.java @@ -0,0 +1,60 @@ +package com.cloudera.cai.rag.files; + +import com.cloudera.cai.util.exceptions.NotFound; +import com.cloudera.cai.util.s3.AmazonS3Client; +import com.cloudera.cai.util.s3.RefCountedS3Client; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Qualifier; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; + +@Slf4j +public class S3RagFileDownloader implements RagFileDownloader { + private final AmazonS3Client s3Client; + private final String bucketName; + + public S3RagFileDownloader( + AmazonS3Client s3Client, @Qualifier("s3BucketName") String s3BucketName) { + this.s3Client = s3Client; + this.bucketName = s3BucketName; + } + + @Override + public InputStream openStream(String s3Path) throws NotFound { + log.info("Downloading file from S3: {}", s3Path); + GetObjectRequest request = GetObjectRequest.builder().bucket(bucketName).key(s3Path).build(); + RefCountedS3Client client = s3Client.getRefCountedClient(); + try { + InputStream inner = client.getClient().getObject(request); + return new ClosingInputStream(inner, client); + } catch (S3Exception e) { + if (e.statusCode() == 404) { + client.close(); + throw new NotFound("no document found with storage path: " + s3Path); + } + client.close(); + throw e; + } + } + + private static class ClosingInputStream extends FilterInputStream { + private final RefCountedS3Client client; + + protected ClosingInputStream(InputStream in, RefCountedS3Client client) { + super(in); + this.client = client; + } + + @Override + public void close() throws IOException { + try { + super.close(); + } finally { + client.close(); + } + } + } +} diff --git a/backend/src/main/java/com/cloudera/cai/util/s3/AmazonS3Client.java b/backend/src/main/java/com/cloudera/cai/util/s3/AmazonS3Client.java index f62b1bf1b..20a8b6751 100644 --- a/backend/src/main/java/com/cloudera/cai/util/s3/AmazonS3Client.java +++ b/backend/src/main/java/com/cloudera/cai/util/s3/AmazonS3Client.java @@ -155,7 +155,7 @@ private S3Client buildAccessKeyClient(Region awsRegion) { } public RefCountedS3Client getRefCountedClient() { - return new RefCountedS3Client(awsCredentials, s3Client, asyncClient, referenceCounter); + return new RefCountedS3Client(s3Client, asyncClient, referenceCounter); } private void initializeWithWebIdentity(Region awsRegion) { diff --git a/backend/src/main/java/com/cloudera/cai/util/s3/RefCountedS3Client.java b/backend/src/main/java/com/cloudera/cai/util/s3/RefCountedS3Client.java index 1522573ce..ae12e8bc7 100644 --- a/backend/src/main/java/com/cloudera/cai/util/s3/RefCountedS3Client.java +++ b/backend/src/main/java/com/cloudera/cai/util/s3/RefCountedS3Client.java @@ -40,23 +40,16 @@ import java.util.concurrent.atomic.AtomicInteger; import lombok.extern.slf4j.Slf4j; -import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; @Slf4j public class RefCountedS3Client implements AutoCloseable { - private final AwsCredentials credentials; private final S3Client s3Client; private final S3AsyncClient asyncClient; private final AtomicInteger referenceCounter; - RefCountedS3Client( - AwsCredentials credentials, - S3Client client, - S3AsyncClient asyncClient, - AtomicInteger counter) { - this.credentials = credentials; + RefCountedS3Client(S3Client client, S3AsyncClient asyncClient, AtomicInteger counter) { this.s3Client = client; this.asyncClient = asyncClient; this.referenceCounter = counter; @@ -67,14 +60,6 @@ public S3Client getClient() { return s3Client; } - public S3AsyncClient getAsyncClient() { - return asyncClient; - } - - public AwsCredentials getCredentials() { - return credentials; - } - @Override public void close() { if (referenceCounter.decrementAndGet() == 0) { diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java index a333f37b4..a3c8d00a0 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileControllerTest.java @@ -47,13 +47,20 @@ import com.cloudera.cai.rag.Types; import com.cloudera.cai.rag.Types.RagDocument; import com.cloudera.cai.rag.datasources.RagDataSourceRepository; +import com.cloudera.cai.util.IdGenerator; import com.cloudera.cai.util.exceptions.BadRequest; +import com.cloudera.cai.util.exceptions.NotFound; +import java.io.ByteArrayOutputStream; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.UUID; import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseEntity; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockMultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; class RagFileControllerTest { @@ -61,7 +68,7 @@ class RagFileControllerTest { private final RagFileRepository ragFileRepository = RagFileRepository.createNull(); @Test - void uploadFile() throws Exception { + void uploadFile() { RagFileController ragFileController = new RagFileController(RagFileService.createNull()); String fileName = "real-filename"; String contentType = "text/plain"; @@ -87,7 +94,7 @@ void uploadFile() throws Exception { } @Test - void uploadFile_noBytes() throws Exception { + void uploadFile_noBytes() { RagFileController ragFileController = new RagFileController(RagFileService.createNull()); String fileName = "file"; String contentType = "text/plain"; @@ -161,7 +168,100 @@ void delete() { var id = TestData.createTestDocument(dataSourceId, documentId, ragFileRepository); RagFileController ragFileController = new RagFileController(RagFileService.createNull()); - ragFileController.deleteRagFile(id, dataSourceId); + ragFileController.deleteRagFile(dataSourceId, documentId); assertThat(ragFileController.getRagDocuments(dataSourceId)).extracting("id").doesNotContain(id); } + + @Test + void download_success_streamsAttachment() throws Exception { + var dsRepo = RagDataSourceRepository.createNull(); + var repo = RagFileRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = UUID.randomUUID().toString(); + String originalFilename = "mydoc.pdf"; + byte[] bytes = "hello world".getBytes(); + String prefix = "prefix"; + String s3Path = prefix + "/" + dataSourceId + "/" + documentId; + + RagFileService ragFileService = + new RagFileService( + IdGenerator.createNull(documentId), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + prefix, + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull(Map.of(s3Path, bytes))); + RagFileController controller = new RagFileController(ragFileService); + + // First upload to create metadata with known documentId + var request = new MockHttpServletRequest(); + TestData.addUserToRequest(request); + controller.uploadRagDocument( + new MockMultipartFile("file", originalFilename, "application/pdf", bytes), + dataSourceId, + request); + + // Find the created document id by filename + String foundDocumentId = + repo.getRagDocuments(dataSourceId).stream() + .filter(d -> d.filename().equals(originalFilename)) + .map(RagDocument::documentId) + .findFirst() + .orElseThrow(); + + ResponseEntity response = + controller.downloadRagDocument(dataSourceId, foundDocumentId); + assertThat(response.getStatusCode().is2xxSuccessful()).isTrue(); + assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_DISPOSITION)) + .isEqualTo("attachment; filename=\"" + originalFilename + "\""); + StreamingResponseBody body = response.getBody(); + assertThat(body).isNotNull(); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + body.writeTo(out); + assertThat(out.toByteArray()).isEqualTo(bytes); + } + + @Test + void download_wrongDataSource_throwsNotFound() { + var dsRepo = RagDataSourceRepository.createNull(); + var repo = RagFileRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = UUID.randomUUID().toString(); + String prefix = "prefix"; + String s3Path = prefix + "/" + dataSourceId + "/" + documentId; + + RagFileService ragFileService = + new RagFileService( + IdGenerator.createNull(documentId), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + prefix, + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull(Map.of(s3Path, "x".getBytes()))); + RagFileController controller = new RagFileController(ragFileService); + + // Create the metadata + var request = new MockHttpServletRequest(); + TestData.addUserToRequest(request); + controller.uploadRagDocument( + new MockMultipartFile("file", "f.txt", "text/plain", new byte[] {1, 2, 3}), + dataSourceId, + request); + + String foundDocumentId = + repo.getRagDocuments(dataSourceId).stream() + .filter(d -> d.filename().equals("f.txt")) + .map(RagDocument::documentId) + .findFirst() + .orElseThrow(); + + assertThatThrownBy(() -> controller.downloadRagDocument(Long.MAX_VALUE, foundDocumentId)) + .isInstanceOf(NotFound.class); + } } diff --git a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java index ea3124a42..2a5b38716 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/files/RagFileServiceTest.java @@ -101,7 +101,7 @@ void deleteRagFile() { String documentId = UUID.randomUUID().toString(); var id = TestData.createTestDocument(dataSourceId, documentId, ragFileRepository); RagFileService ragFileService = createRagFileService(); - ragFileService.deleteRagFile(id, dataSourceId); + ragFileService.deleteRagFileByDocumentId(documentId, dataSourceId); assertThat(ragFileService.getRagDocuments(dataSourceId)).extracting("id").doesNotContain(id); } @@ -110,11 +110,12 @@ void deleteRagFile_wrongDataSourceId() { RagFileRepository ragFileRepository = RagFileRepository.createNull(); var dataSourceId = TestData.createTestDataSource(RagDataSourceRepository.createNull()); String documentId = UUID.randomUUID().toString(); - var id = TestData.createTestDocument(dataSourceId, documentId, ragFileRepository); + TestData.createTestDocument(dataSourceId, documentId, ragFileRepository); RagFileService ragFileService = createRagFileService(); Long nonExistentDataSourceId = Long.MAX_VALUE; - assertThatThrownBy(() -> ragFileService.deleteRagFile(id, nonExistentDataSourceId)) + assertThatThrownBy( + () -> ragFileService.deleteRagFileByDocumentId(documentId, nonExistentDataSourceId)) .isInstanceOf(NotFound.class); } @@ -224,7 +225,8 @@ private RagFileService createRagFileService( prefix, dataSourceRepository, RagFileDeleteReconciler.createNull(), - RagFileSummaryReconciler.createNull()); + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull()); } private long newDataSourceId() { @@ -311,4 +313,80 @@ private MockMultipartFile createZipFile(String[][] fileEntries, String contentTy } return new MockMultipartFile("test.zip", "test.zip", contentType, outputStream.toByteArray()); } + + @Test + void downloadDocumentById_success() throws Exception { + var repo = RagFileRepository.createNull(); + var dsRepo = RagDataSourceRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = UUID.randomUUID().toString(); + Long id = TestData.createTestDocument(dataSourceId, documentId, repo); + + byte[] content = "hello by id".getBytes(); + RagFileService service = + new RagFileService( + IdGenerator.createNull(), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + "prefix", + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull(java.util.Map.of("doesn't matter", content))); + + var downloaded = service.downloadDocument(dataSourceId, id); + assertThat(downloaded.filename()).isNotNull(); + try (var in = downloaded.stream()) { + byte[] read = in.readAllBytes(); + assertThat(read).isNotEmpty(); + } + } + + @Test + void downloadDocumentById_wrongDataSourceId() { + var repo = RagFileRepository.createNull(); + var dsRepo = RagDataSourceRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + String documentId = UUID.randomUUID().toString(); + Long id = TestData.createTestDocument(dataSourceId, documentId, repo); + + RagFileService service = + new RagFileService( + IdGenerator.createNull(), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + "prefix", + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull()); + + assertThatThrownBy(() -> service.downloadDocument(Long.MAX_VALUE, id)) + .isInstanceOf(NotFound.class); + } + + @Test + void downloadDocumentById_notFound() { + var repo = RagFileRepository.createNull(); + var dsRepo = RagDataSourceRepository.createNull(); + long dataSourceId = TestData.createTestDataSource(dsRepo); + + RagFileService service = + new RagFileService( + IdGenerator.createNull(), + repo, + RagFileUploader.createNull(), + RagFileIndexReconciler.createNull(), + "prefix", + dsRepo, + RagFileDeleteReconciler.createNull(), + RagFileSummaryReconciler.createNull(), + RagFileDownloader.createNull()); + + // Using a negative id to ensure it's not present + assertThatThrownBy(() -> service.downloadDocument(dataSourceId, -9999L)) + .isInstanceOf(NotFound.class); + } } diff --git a/ui/src/api/ragDocumentsApi.ts b/ui/src/api/ragDocumentsApi.ts index 8de33af47..7e59d88b2 100644 --- a/ui/src/api/ragDocumentsApi.ts +++ b/ui/src/api/ragDocumentsApi.ts @@ -78,7 +78,7 @@ const createRagDocumentsMutation = async ({ dataSourceId: string; }) => { const promises = files.map((file) => - createRagDocumentMutation(file, dataSourceId), + createRagDocumentMutation(file, dataSourceId) ); return Promise.allSettled(promises); @@ -93,16 +93,16 @@ const createRagDocumentMutation = async (file: File, dataSourceId: string) => { method: "POST", body: formData, headers: commonHeaders, - }, + } ).then((res) => { if (!res.ok) { if (res.status === 413) { throw new Error( - `File is too large. Maximum size is 100MB: ${file.name}`, + `File is too large. Maximum size is 100MB: ${file.name}` ); } throw new Error( - `Failed to call API backend. status: ${res.status.toString()} : ${res.statusText}`, + `Failed to call API backend. status: ${res.status.toString()} : ${res.statusText}` ); } return res.json() as Promise; @@ -138,7 +138,7 @@ export interface RagDocumentResponseType { export const useGetRagDocuments = ( dataSourceId?: string, - summarizationModel?: string, + summarizationModel?: string ) => { return useQuery({ queryKey: [QueryKeys.getRagDocuments, { dataSourceId }], @@ -158,7 +158,7 @@ export const useGetRagDocuments = ( return false; } const nullTimestampDocuments = data.find( - (file: RagDocumentResponseType) => file.vectorUploadTimestamp === null, + (file: RagDocumentResponseType) => file.vectorUploadTimestamp === null ); let nullSummaryCreation = null; @@ -166,7 +166,7 @@ export const useGetRagDocuments = ( if (summarizationModel && summarizationModel.length > 0) { nullSummaryCreation = data.find( (file: RagDocumentResponseType) => - file.summaryCreationTimestamp === null, + file.summaryCreationTimestamp === null ); } @@ -176,10 +176,10 @@ export const useGetRagDocuments = ( }; const getRagDocuments = async ( - dataSourceId: string, + dataSourceId: string ): Promise => { return getRequest( - `${ragPath}/${paths.dataSources}/${dataSourceId}/${paths.files}`, + `${ragPath}/${paths.dataSources}/${dataSourceId}/${paths.files}` ); }; @@ -196,13 +196,13 @@ export const useDeleteDocumentMutation = ({ }; export const deleteDocumentMutation = async ({ - id, + documentId, dataSourceId, }: { - id: number; + documentId: string; dataSourceId: string; }): Promise => { await deleteRequest( - `${ragPath}/${paths.dataSources}/${dataSourceId}/${paths.files}/${id.toString()}`, + `${ragPath}/${paths.dataSources}/${dataSourceId}/${paths.files}/${documentId}` ); }; diff --git a/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx b/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx index a6dae7795..d4ff6d48f 100644 --- a/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx +++ b/ui/src/pages/DataSources/ManageTab/UploadedFilesTable.tsx @@ -45,7 +45,7 @@ import { Tooltip, Typography, } from "antd"; -import Icon, { DeleteOutlined } from "@ant-design/icons"; +import Icon, { DeleteOutlined, DownloadOutlined } from "@ant-design/icons"; import { RagDocumentResponseType, useDeleteDocumentMutation, @@ -57,18 +57,19 @@ import AiAssistantIcon from "src/cuix/icons/AiAssistantIcon"; import { useState } from "react"; import messageQueue from "src/utils/messageQueue.ts"; import { useQueryClient } from "@tanstack/react-query"; -import { QueryKeys } from "src/api/utils.ts"; +import { paths, QueryKeys, ragPath } from "src/api/utils.ts"; import useModal from "src/utils/useModal.ts"; import { cdlWhite } from "src/cuix/variables.ts"; import ReadyColumn from "pages/DataSources/ManageTab/ReadyColumn.tsx"; import SummaryColumn from "pages/DataSources/ManageTab/SummaryColumn.tsx"; import { ColumnsType } from "antd/es/table"; +import { downloadFile } from "src/utils/downloadFile.ts"; const columns = ( dataSourceId: string, handleDeleteFile: (document: RagDocumentResponseType) => void, simpleColumns: boolean, - summarizationModel?: string, + summarizationModel?: string ): TableProps["columns"] => { let columns: ColumnsType = [ { @@ -152,14 +153,26 @@ const columns = ( { title: "Actions", render: (_, record) => { + const handleDownloadFile = () => { + const url = `${ragPath}/${paths.dataSources}/${record.dataSourceId.toString()}/${paths.files}/${record.documentId}/download`; + void downloadFile(url, record.filename, { pageNumber: "2" }); + }; + return ( -