Skip to content

Instantly share code, notes, and snippets.

@kojilin
Last active August 15, 2021 22:56
Show Gist options
  • Select an option

  • Save kojilin/575fdeb752a22e6d6a328c5ce545d4a1 to your computer and use it in GitHub Desktop.

Select an option

Save kojilin/575fdeb752a22e6d6a328c5ce545d4a1 to your computer and use it in GitHub Desktop.
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import com.linecorp.armeria.common.CommonPools;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.QueryParams;
import com.linecorp.armeria.common.QueryParamsBuilder;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.internal.common.HttpObjectAggregator;
/**
 * Collects a multipart stream into text fields and temporary files.
 *
 * <p>Parts without a filename are aggregated in memory and exposed as {@link QueryParams} keyed by the
 * part name. Parts with a filename are streamed into a newly created temporary file, whose {@link Path} is
 * recorded under the part name (or the filename if the part has no name). Body parts are requested one at
 * a time: the next part is requested only after the current one has been fully processed. File writes run
 * on {@link CommonPools#blockingTaskExecutor()}.
 *
 * <p>If anything fails, {@link #future()} completes exceptionally. The temporary files created so far are
 * then deleted on a best-effort basis, because the caller never receives their paths.
 */
public class MultipartCollector implements Subscriber<BodyPart> {

    private final CompletableFuture<CustomAggregatedMultipart> future = new CompletableFuture<>();
    private final QueryParamsBuilder queryParamsBuilder = QueryParams.builder();
    private final Map<String, List<Path>> files = new HashMap<>();

    @Nullable
    private Subscription multipartSubscription;

    // Completes when the most recently received body part has been fully processed. Only read and written
    // from onNext/onComplete, which a Reactive Streams publisher signals serially.
    private CompletableFuture<Void> lastPart = CompletableFuture.completedFuture(null);

    /**
     * Returns the future that completes with the collected fields and files after the multipart stream has
     * ended and every body part has been fully processed, or exceptionally on the first failure.
     */
    public CompletableFuture<CustomAggregatedMultipart> future() {
        return future;
    }

    @Override
    public void onSubscribe(Subscription s) {
        multipartSubscription = s;
        s.request(1);
    }

    /**
     * Processes one body part, then requests the next one. Processing happens in memory for fields and into
     * a temporary file for files. Any failure cancels the multipart stream and fails {@link #future()}.
     */
    @Override
    public void onNext(BodyPart bodyPart) {
        final CompletableFuture<Void> partFuture =
                bodyPart.filename() == null ? collectField(bodyPart) : collectFile(bodyPart);
        lastPart = partFuture;
        partFuture.whenComplete((unused, cause) -> {
            if (cause != null) {
                fail(cause);
                return;
            }
            assert multipartSubscription != null;
            multipartSubscription.request(1);
        });
    }

    @Override
    public void onError(Throwable t) {
        if (future.completeExceptionally(t)) {
            deleteFilesQuietly();
        }
    }

    @Override
    public void onComplete() {
        // The multipart stream may complete before the last part's asynchronous file I/O has finished,
        // so wait for that part before publishing the result. A failed part is reported by the callback
        // registered in onNext; complete() is a no-op if the future is already done.
        lastPart.whenComplete((unused, cause) -> {
            if (cause == null) {
                future.complete(new CustomAggregatedMultipart(queryParamsBuilder.build(), files));
            }
        });
    }

    /**
     * Aggregates a part without a filename in memory and adds its decoded text to the form fields.
     *
     * @return a future completing when the field has been recorded, or exceptionally on failure
     */
    private CompletableFuture<Void> collectField(BodyPart bodyPart) {
        final CompletableFuture<AggregatedBodyPart> aggregated = new CompletableFuture<>();
        bodyPart.content().subscribe(new HttpObjectAggregator<AggregatedBodyPart>(aggregated, null) {
            @Override
            protected void onHeaders(HttpHeaders headers) {
            }

            @Override
            protected AggregatedBodyPart onSuccess(HttpData content) {
                return AggregatedBodyPart.of(bodyPart.headers(), content);
            }

            @Override
            protected void onFailure() {
            }
        });

        final CompletableFuture<Void> done = new CompletableFuture<>();
        aggregated.whenComplete((part, cause) -> {
            if (cause != null) {
                done.completeExceptionally(cause);
                return;
            }
            try {
                @Nullable
                final String name = part.name();
                if (name != null) {
                    @Nullable
                    final MediaType mediaType = part.contentType();
                    // HTML forms usually omit the part Content-Type for text fields and encode them as
                    // UTF-8 (RFC 7578, section 5.1). A US-ASCII default would mangle non-ASCII input.
                    final Charset charset = mediaType == null
                                            ? StandardCharsets.UTF_8
                                            : mediaType.charset(StandardCharsets.UTF_8);
                    queryParamsBuilder.add(name, part.content(charset));
                }
                done.complete(null);
            } catch (RuntimeException e) {
                // Any unexpected failure must still complete `done`; otherwise the request never finishes.
                done.completeExceptionally(e);
            }
        });
        return done;
    }

    /**
     * Streams a part with a filename into a new temporary file and records the file's path.
     *
     * @return a future completing once the file has been fully written and closed, or exceptionally
     */
    private CompletableFuture<Void> collectFile(BodyPart bodyPart) {
        final Path path;
        try {
            path = Files.createTempFile("armeria", ".tmp");
        } catch (IOException e) {
            final CompletableFuture<Void> failed = new CompletableFuture<>();
            failed.completeExceptionally(e);
            return failed;
        }
        // Key by the form field name, as text fields are, so that files uploaded under different fields
        // never collide. Fall back to the filename only when the part carries no name.
        @Nullable
        final String name = bodyPart.name();
        final String key = name != null ? name : bodyPart.filename();
        files.computeIfAbsent(key, unused -> new ArrayList<>()).add(path);

        // TODO: onNext does not seem to run with the request context; consider passing the context's event
        //       loop (RequestContext#eventLoop()#withoutContext()) to subscribe().
        final CompletableFuture<Void> done = new CompletableFuture<>();
        bodyPart.content().subscribe(new FileWritingSubscriber(path, done));
        return done;
    }

    /** Cancels the multipart stream, fails the result, and discards the temporary files created so far. */
    private void fail(Throwable cause) {
        assert multipartSubscription != null;
        multipartSubscription.cancel();
        if (future.completeExceptionally(cause)) {
            deleteFilesQuietly();
        }
    }

    /** Deletes every temporary file recorded so far on the blocking executor, ignoring I/O failures. */
    private void deleteFilesQuietly() {
        // Snapshot on the calling thread so the blocking task never touches the HashMap concurrently.
        final List<Path> paths = new ArrayList<>();
        for (List<Path> perKey : files.values()) {
            paths.addAll(perKey);
        }
        if (paths.isEmpty()) {
            return;
        }
        CommonPools.blockingTaskExecutor().execute(() -> {
            for (Path p : paths) {
                try {
                    Files.deleteIfExists(p);
                } catch (IOException ignored) {
                    // Best effort: the request has already failed and nobody else holds these paths.
                }
            }
        });
    }

    /**
     * Writes the chunks of one file part to a file. Completes {@code done} normally once the file is
     * closed, or exceptionally on the first failure. All file I/O after opening runs sequentially on the
     * blocking executor, and the next chunk is requested only after the previous one has been written.
     */
    private static final class FileWritingSubscriber implements Subscriber<HttpData> {

        private final Path path;
        private final CompletableFuture<Void> done;

        @Nullable
        private Subscription subscription;
        @Nullable
        private BufferedOutputStream out;
        // Tail of the I/O chain; every signal appends one step so that steps run in arrival order.
        private CompletableFuture<Void> pendingIo = CompletableFuture.completedFuture(null);

        FileWritingSubscriber(Path path, CompletableFuture<Void> done) {
            this.path = path;
            this.done = done;
        }

        @Override
        public void onSubscribe(Subscription s) {
            subscription = s;
            try {
                out = new BufferedOutputStream(new FileOutputStream(path.toFile()));
            } catch (IOException e) {
                // Stop this part's stream as well; the multipart stream is cancelled via `done`.
                s.cancel();
                done.completeExceptionally(e);
                return;
            }
            s.request(1);
        }

        @Override
        public void onNext(HttpData data) {
            runIo(() -> {
                assert out != null && subscription != null;
                out.write(data.array());
                subscription.request(1);
            });
        }

        @Override
        public void onError(Throwable t) {
            pendingIo = pendingIo.thenRunAsync(() -> {
                closeQuietly();
                // Propagate the failure so the whole aggregation fails instead of silently stalling.
                done.completeExceptionally(t);
            }, CommonPools.blockingTaskExecutor());
        }

        @Override
        public void onComplete() {
            runIo(() -> {
                assert out != null;
                out.close();
                done.complete(null);
            });
        }

        /** Appends an I/O step to the chain; on failure cancels the part, closes the file and fails. */
        private void runIo(IoTask task) {
            pendingIo = pendingIo.thenRunAsync(() -> {
                if (done.isDone()) {
                    // An earlier step already failed and cleaned up; ignore the remaining signals.
                    return;
                }
                try {
                    task.run();
                } catch (Exception e) {
                    // Catch unchecked exceptions too; an escaping one would skip every later step and
                    // leave `done` incomplete forever.
                    assert subscription != null;
                    subscription.cancel();
                    closeQuietly();
                    done.completeExceptionally(e);
                }
            }, CommonPools.blockingTaskExecutor());
        }

        /** Closes the output stream if it was opened, ignoring errors because a failure is being reported. */
        private void closeQuietly() {
            if (out == null) {
                return;
            }
            try {
                out.close();
            } catch (IOException ignored) {
                // Keep the original failure as the reported cause.
            }
        }
    }

    /** A unit of file I/O that may throw {@link IOException}. */
    @FunctionalInterface
    private interface IoTask {
        void run() throws IOException;
    }

    /** The result of {@link MultipartCollector}: text fields plus the temporary files of file parts. */
    static class CustomAggregatedMultipart {
        private final QueryParams queryParams;
        private final Map<String, List<Path>> files;

        public CustomAggregatedMultipart(QueryParams queryParams,
                                         Map<String, List<Path>> files) {
            this.queryParams = queryParams;
            this.files = files;
        }

        /** Returns the decoded text fields, in the order they were received. */
        public QueryParams getQueryParams() {
            return queryParams;
        }

        /** Returns the temporary files written for each file part, grouped by key. */
        public Map<String, List<Path>> getFiles() {
            return files;
        }
    }
}
import static org.assertj.core.api.Assertions.assertThat;
import java.io.File;
import java.net.MalformedURLException;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileUrlResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import com.linecorp.armeria.common.HttpObject;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;
/**
 * End-to-end test for {@link MultipartCollector}. It posts a multipart form through Spring's
 * {@link RestTemplate} and checks the fields and files that the collector reports back.
 */
class MultipartCollectorTest {

    @RegisterExtension
    static final ServerExtension server = new ServerExtension() {
        @Override
        protected void configure(ServerBuilder sb) throws Exception {
            // Responds with "<query string of the text fields>/<map of collected temporary files>".
            sb.service("/multipart/file", (ctx, req) -> {
                final MultipartCollector multipartCollector = new MultipartCollector();
                Multipart.from(req).bodyParts().subscribe(multipartCollector);
                return HttpResponse.from(
                        multipartCollector.future().thenApply(aggregated -> HttpResponse.of(
                                aggregated.getQueryParams().toQueryString() + "/" + aggregated.getFiles())));
            });
        }
    };

    @Test
    void multipartFile() throws Exception {
        final HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.MULTIPART_FORM_DATA);
        // Only a classpath resource is uploaded; an absolute path on one developer's machine would make
        // the test fail everywhere else.
        final ClassPathResource file = new ClassPathResource("test.txt");
        final MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
        body.add("file2", file);
        body.add("file3", file);
        body.add("file3", file);
        body.add("foo", "bar");
        body.add("foo", "実装");
        body.add("qoo", "hoge");
        final HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
        final RestTemplate restTemplate = new RestTemplate();
        final ResponseEntity<String> response =
                restTemplate.postForEntity(server.httpUri().resolve("/multipart/file"), requestEntity,
                                           String.class);

        // Temporary file names are random and HashMap ordering is unspecified. The text fields are
        // compared exactly; for files, the test only checks that some were reported.
        final String responseBody = response.getBody();
        assertThat(responseBody).isNotNull();
        final int separator = responseBody.indexOf('/');
        assertThat(separator).isPositive();
        final String queryString = responseBody.substring(0, separator);
        assertThat(java.net.URLDecoder.decode(queryString, "UTF-8"))
                .isEqualTo("foo=bar&foo=実装&qoo=hoge");
        assertThat(responseBody.substring(separator + 1)).startsWith("{").isNotEqualTo("{}");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment