Skip to content

Implement download file link #289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@

package com.cloudera.cai.rag.configuration;

import com.cloudera.cai.rag.files.FileSystemRagFileDownloader;
import com.cloudera.cai.rag.files.FileSystemRagFileUploader;
import com.cloudera.cai.rag.files.RagFileDownloader;
import com.cloudera.cai.rag.files.RagFileUploader;
import com.cloudera.cai.rag.files.S3RagFileDownloader;
import com.cloudera.cai.rag.files.S3RagFileUploader;
import com.cloudera.cai.util.reconcilers.ReconcilerConfig;
import com.cloudera.cai.util.s3.AmazonS3Client;
Expand Down Expand Up @@ -162,6 +165,15 @@ public RagFileUploader ragFileUploader(S3Config configuration) {
return new S3RagFileUploader(s3Client, configuration.getBucketName());
}

/**
 * Chooses the {@link RagFileDownloader} implementation from the S3 configuration: the file-system
 * downloader when the configured bucket name is empty, otherwise an S3-backed downloader for that
 * bucket.
 */
@Bean
public RagFileDownloader ragFileDownloader(S3Config configuration) {
  String bucketName = configuration.getBucketName();
  if (!bucketName.isEmpty()) {
    return new S3RagFileDownloader(new AmazonS3Client(configuration), bucketName);
  }
  return new FileSystemRagFileDownloader();
}

public static String getLlmServiceUrl() {
var llmServiceUrl =
Optional.ofNullable(System.getenv("LLM_SERVICE_URL"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.cloudera.cai.rag.files;

import com.cloudera.cai.util.exceptions.NotFound;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link RagFileDownloader} that serves documents from the local file system. Files are looked up
 * under the {@code file_storage} directory of the location named by the {@code RAG_DATABASES_DIR}
 * environment variable.
 */
@Slf4j
public class FileSystemRagFileDownloader implements RagFileDownloader {
  private static final String FILE_STORAGE_ROOT = fileStoragePath();

  /** Resolves (and logs) the storage root once, at class initialization. */
  private static String fileStoragePath() {
    var fileStoragePath = System.getenv("RAG_DATABASES_DIR") + "/file_storage";
    log.info("configured with fileStoragePath = {}", fileStoragePath);
    return fileStoragePath;
  }

  /**
   * Opens the file stored at {@code s3Path}, interpreted relative to the storage root. The caller
   * owns the returned stream and must close it.
   *
   * @param s3Path storage path of the document, relative to the storage root
   * @return an open stream over the file's contents
   * @throws NotFound if the path escapes the storage root, or does not name a readable regular
   *     file
   * @throws RuntimeException wrapping any other {@link IOException} raised while opening the file
   */
  @Override
  public InputStream openStream(String s3Path) throws NotFound {
    Path root = Path.of(FILE_STORAGE_ROOT).normalize();
    Path filePath = Path.of(FILE_STORAGE_ROOT, s3Path);

    // Defense in depth: a stored path containing ".." must never reach files outside the root.
    boolean insideRoot = filePath.normalize().startsWith(root);
    // isRegularFile (rather than exists) also rejects directories, which cannot be streamed.
    if (!insideRoot || !Files.isRegularFile(filePath)) {
      throw new NotFound("no document found with storage path: " + s3Path);
    }

    try {
      return Files.newInputStream(filePath);
    } catch (java.nio.file.NoSuchFileException e) {
      // The file was removed between the check above and the open; still a "not found".
      throw new NotFound("no document found with storage path: " + s3Path);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;

@RestController
@Slf4j
Expand Down Expand Up @@ -82,8 +86,27 @@ public List<Types.RagDocument> getRagDocuments(@PathVariable Long dataSourceId)
return ragFileService.getRagDocuments(dataSourceId);
}

@DeleteMapping(value = "/dataSources/{dataSourceId}/files/{id}")
public void deleteRagFile(@PathVariable Long id, @PathVariable Long dataSourceId) {
ragFileService.deleteRagFile(id, dataSourceId);
/**
 * Handles {@code DELETE /dataSources/{dataSourceId}/files/{documentId}} by delegating to
 * {@code RagFileService#deleteRagFileByDocumentId}.
 */
@DeleteMapping("/dataSources/{dataSourceId}/files/{documentId}")
public void deleteRagFile(
    @PathVariable Long dataSourceId,
    @PathVariable String documentId) {
  ragFileService.deleteRagFileByDocumentId(documentId, dataSourceId);
}

/**
 * Streams a stored document back to the client as a file attachment.
 *
 * <p>The document is looked up through {@code RagFileService#downloadDocumentByDocumentId}; its
 * stream is copied to the response and closed once the copy finishes (or fails).
 *
 * @param dataSourceId id of the data source taken from the request path
 * @param documentId document id taken from the request path
 * @return a 200 response with an {@code application/octet-stream} body and a
 *     {@code Content-Disposition: attachment} header carrying the document's filename
 */
@GetMapping(value = "/dataSources/{dataSourceId}/files/{documentId}/download")
public ResponseEntity<StreamingResponseBody> downloadRagDocument(
    @PathVariable Long dataSourceId, @PathVariable String documentId) {
  var downloaded = ragFileService.downloadDocumentByDocumentId(dataSourceId, documentId);

  // Let Spring encode the filename instead of concatenating it into the header by hand: a name
  // containing a double quote or non-ASCII characters would otherwise produce a malformed (or
  // spoofable) Content-Disposition value.
  // NOTE(review): some Spring versions reject a blank filename here - confirm filenames are
  // always non-empty for stored documents.
  var contentDisposition =
      org.springframework.http.ContentDisposition.attachment()
          .filename(downloaded.filename(), java.nio.charset.StandardCharsets.UTF_8)
          .build();

  StreamingResponseBody body =
      outputStream -> {
        // try-with-resources releases the underlying storage stream after the copy.
        try (var in = downloaded.stream()) {
          in.transferTo(outputStream);
        }
      };

  return ResponseEntity.ok()
      .header(HttpHeaders.CONTENT_DISPOSITION, contentDisposition.toString())
      .contentType(MediaType.APPLICATION_OCTET_STREAM)
      .body(body);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.cloudera.cai.rag.files;

import com.cloudera.cai.util.exceptions.NotFound;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

/** Abstraction over document storage that hands out an InputStream for a stored document. */
public interface RagFileDownloader {
  /**
   * Opens a stream over the document stored at the given storage path. Closing the returned
   * stream is the caller's responsibility.
   *
   * @param s3Path storage path of the document
   * @return an open stream over the document's bytes
   * @throws NotFound if no document exists at that path
   */
  InputStream openStream(String s3Path) throws NotFound;

  /**
   * Creates an in-memory test double with no documents, so every lookup reports {@link NotFound}.
   * Lets unit tests avoid a mocking framework.
   */
  static RagFileDownloader createNull() {
    return createNull(Map.of());
  }

  /**
   * Creates an in-memory test double that serves the given bytes for each storage path.
   *
   * @param pathToBytes contents keyed by storage path; the map itself is copied on creation
   */
  static RagFileDownloader createNull(Map<String, byte[]> pathToBytes) {
    Map<String, byte[]> contentsByPath = new HashMap<>(pathToBytes);
    return requestedPath -> {
      byte[] contents = contentsByPath.get(requestedPath);
      if (contents != null) {
        return new ByteArrayInputStream(contents);
      }
      throw new NotFound("no document found with storage path: " + requestedPath);
    };
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public class RagFileService {
private final RagDataSourceRepository ragDataSourceRepository;
private final RagFileDeleteReconciler ragFileDeleteReconciler;
private final RagFileSummaryReconciler ragFileSummaryReconciler;
private final RagFileDownloader ragFileDownloader;

@Autowired
public RagFileService(
Expand All @@ -83,7 +84,8 @@ public RagFileService(
@Qualifier("s3BucketPrefix") String s3PathPrefix,
RagDataSourceRepository ragDataSourceRepository,
RagFileDeleteReconciler ragFileDeleteReconciler,
RagFileSummaryReconciler ragFileSummaryReconciler) {
RagFileSummaryReconciler ragFileSummaryReconciler,
RagFileDownloader ragFileDownloader) {
this.idGenerator = idGenerator;
this.ragFileRepository = ragFileRepository;
this.ragFileUploader = ragFileUploader;
Expand All @@ -92,6 +94,7 @@ public RagFileService(
this.ragDataSourceRepository = ragDataSourceRepository;
this.ragFileDeleteReconciler = ragFileDeleteReconciler;
this.ragFileSummaryReconciler = ragFileSummaryReconciler;
this.ragFileDownloader = ragFileDownloader;
}

public List<RagDocumentMetadata> saveRagFile(
Expand Down Expand Up @@ -201,19 +204,47 @@ private static String validateFilename(String originalFilename) {
return originalFilename;
}

public void deleteRagFile(Long id, Long dataSourceId) {
var document = ragFileRepository.getRagDocumentById(id);
public void deleteRagFileByDocumentId(String documentId, Long dataSourceId) {
var document = ragFileRepository.findDocumentByDocumentId(documentId);
if (!document.dataSourceId().equals(dataSourceId)) {
throw new NotFound("Document with id " + id + " not found for dataSourceId: " + dataSourceId);
throw new NotFound(
"Document with documentId "
+ documentId
+ " not found for dataSourceId: "
+ dataSourceId);
}
ragFileRepository.deleteById(id);
ragFileRepository.deleteById(document.id());
ragFileDeleteReconciler.submit(document);
}

/** Returns the documents the repository reports for the given data source. */
public List<RagDocument> getRagDocuments(Long dataSourceId) {
  List<RagDocument> documents = ragFileRepository.getRagDocuments(dataSourceId);
  return documents;
}

/**
 * A document opened for download: its filename, its storage path, and an open stream over its
 * contents.
 */
public record DownloadedDocument(
    String filename,
    String s3Path,
    java.io.InputStream stream) {}

/**
 * Opens the document with the given numeric id for download.
 *
 * @param dataSourceId data source the document is expected to belong to
 * @param id numeric id of the document
 * @return the document's filename and storage path together with an open stream
 * @throws NotFound if the document's data source does not match {@code dataSourceId}
 */
public DownloadedDocument downloadDocument(Long dataSourceId, Long id) {
  var ragDocument = ragFileRepository.getRagDocumentById(id);
  boolean belongsToDataSource = ragDocument.dataSourceId().equals(dataSourceId);
  if (belongsToDataSource) {
    var contents = ragFileDownloader.openStream(ragDocument.s3Path());
    return new DownloadedDocument(ragDocument.filename(), ragDocument.s3Path(), contents);
  }
  throw new NotFound("Document with id " + id + " not found for dataSourceId: " + dataSourceId);
}

/**
 * Opens the document with the given document id for download.
 *
 * @param dataSourceId data source the document is expected to belong to
 * @param documentId string document id used to look the document up
 * @return the document's filename and storage path together with an open stream
 * @throws NotFound if the document's data source does not match {@code dataSourceId}
 */
public DownloadedDocument downloadDocumentByDocumentId(Long dataSourceId, String documentId) {
  var ragDocument = ragFileRepository.findDocumentByDocumentId(documentId);
  boolean belongsToDataSource = ragDocument.dataSourceId().equals(dataSourceId);
  if (belongsToDataSource) {
    var contents = ragFileDownloader.openStream(ragDocument.s3Path());
    return new DownloadedDocument(ragDocument.filename(), ragDocument.s3Path(), contents);
  }
  String message =
      "Document with documentId " + documentId + " not found for dataSourceId: " + dataSourceId;
  throw new NotFound(message);
}

public record MultipartUploadableFile(MultipartFile file) implements UploadableFile {

@Override
Expand Down Expand Up @@ -262,6 +293,7 @@ public static RagFileService createNull(String... dummyIds) {
"prefix",
RagDataSourceRepository.createNull(),
RagFileDeleteReconciler.createNull(),
RagFileSummaryReconciler.createNull());
RagFileSummaryReconciler.createNull(),
RagFileDownloader.createNull());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.cloudera.cai.rag.files;

import com.cloudera.cai.util.exceptions.NotFound;
import com.cloudera.cai.util.s3.AmazonS3Client;
import com.cloudera.cai.util.s3.RefCountedS3Client;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Qualifier;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Exception;

/**
 * {@link RagFileDownloader} that streams documents out of an S3 bucket. Each opened stream holds
 * a {@link RefCountedS3Client} reference, which is released when the stream is closed.
 */
@Slf4j
public class S3RagFileDownloader implements RagFileDownloader {
  private final AmazonS3Client s3Client;
  private final String bucketName;

  public S3RagFileDownloader(
      AmazonS3Client s3Client, @Qualifier("s3BucketName") String s3BucketName) {
    this.s3Client = s3Client;
    this.bucketName = s3BucketName;
  }

  /**
   * Opens a stream over the S3 object stored under {@code s3Path} in the configured bucket. The
   * caller must close the returned stream; closing it also releases the client reference.
   *
   * @param s3Path object key of the document
   * @return an open stream over the object's contents
   * @throws NotFound if S3 answers the request with status 404
   */
  @Override
  public InputStream openStream(String s3Path) throws NotFound {
    log.info("Downloading file from S3: {}", s3Path);
    GetObjectRequest request = GetObjectRequest.builder().bucket(bucketName).key(s3Path).build();
    RefCountedS3Client client = s3Client.getRefCountedClient();
    try {
      InputStream inner = client.getClient().getObject(request);
      return new ClosingInputStream(inner, client);
    } catch (S3Exception e) {
      client.close();
      if (e.statusCode() == 404) {
        throw new NotFound("no document found with storage path: " + s3Path);
      }
      throw e;
    } catch (RuntimeException e) {
      // Non-service failures (e.g. client-side/network errors) must release the reference too,
      // otherwise the shared reference count never returns to zero.
      client.close();
      throw e;
    }
  }

  /** Stream wrapper that releases the ref-counted client exactly once, when first closed. */
  private static class ClosingInputStream extends FilterInputStream {
    private final RefCountedS3Client client;
    // Guards against double release: close() may legally be invoked more than once, and each
    // client.close() decrements the shared reference counter.
    private boolean released;

    protected ClosingInputStream(InputStream in, RefCountedS3Client client) {
      super(in);
      this.client = client;
    }

    @Override
    public void close() throws IOException {
      try {
        super.close();
      } finally {
        if (!released) {
          released = true;
          client.close();
        }
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ private S3Client buildAccessKeyClient(Region awsRegion) {
}

public RefCountedS3Client getRefCountedClient() {
return new RefCountedS3Client(awsCredentials, s3Client, asyncClient, referenceCounter);
return new RefCountedS3Client(s3Client, asyncClient, referenceCounter);
}

private void initializeWithWebIdentity(Region awsRegion) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,16 @@

import java.util.concurrent.atomic.AtomicInteger;
import lombok.extern.slf4j.Slf4j;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;

@Slf4j
public class RefCountedS3Client implements AutoCloseable {
private final AwsCredentials credentials;
private final S3Client s3Client;
private final S3AsyncClient asyncClient;
private final AtomicInteger referenceCounter;

RefCountedS3Client(
AwsCredentials credentials,
S3Client client,
S3AsyncClient asyncClient,
AtomicInteger counter) {
this.credentials = credentials;
RefCountedS3Client(S3Client client, S3AsyncClient asyncClient, AtomicInteger counter) {
this.s3Client = client;
this.asyncClient = asyncClient;
this.referenceCounter = counter;
Expand All @@ -67,14 +60,6 @@ public S3Client getClient() {
return s3Client;
}

public S3AsyncClient getAsyncClient() {
return asyncClient;
}

public AwsCredentials getCredentials() {
return credentials;
}

@Override
public void close() {
if (referenceCounter.decrementAndGet() == 0) {
Expand Down
Loading
Loading