@sshaaf · Last active September 25, 2025 22:58

LargeScaleVectors.java
import java.io.*;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Stream;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import java.nio.ByteOrder;
/**
* A command-line application to find the top N most frequent words from a large
* collection of files, with detailed performance metrics.
*
* This version is inspired by the "One Billion Row Challenge" (1BRC) and uses
* advanced techniques like memory-mapped files and the Java Vector API for
* maximum performance.
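*
* Requires JDK 21+: the Vector API lives in the jdk.incubator.vector
* incubator module and StructuredTaskScope is a preview API there. A
* sketch of an invocation (exact flags depend on your JDK release):
*
*   javac --release 21 --enable-preview --add-modules jdk.incubator.vector LargeScaleVectors.java
*   java --enable-preview --add-modules jdk.incubator.vector LargeScaleVectors 10 /path/to/corpus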
*/
public class LargeScaleVectors {
// Use the largest vector size our CPU supports for maximum SIMD throughput.
private static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED;
/**
* Inner class to hold all performance-related metrics.
*/
private static class PerformanceMetrics {
long totalTimeMs;
long preProcessingTimeMs;
long mapPhaseTimeMs;
long mergePhaseTimeMs;
long topNPhaseTimeMs;
final AtomicLong totalBytesRead = new AtomicLong(0);
final AtomicLong totalBytesSpilled = new AtomicLong(0);
final AtomicLong finalFileBytes = new AtomicLong(0);
final AtomicLong peakMemoryUsedBytes = new AtomicLong(0);
final AtomicLong totalUniqueWords = new AtomicLong(0);
long spillFileCount = 0;
public void printReport() {
System.out.println("\n--- PERFORMANCE METRICS ------------------------");
System.out.printf(" Total Execution Time : %d ms%n", totalTimeMs);
System.out.println("-------------------------------------------------");
System.out.println("PHASE TIMINGS:");
System.out.printf(" - Pre-processing : %d ms%n", preProcessingTimeMs);
System.out.printf(" - Map & Spill Phase : %d ms%n", mapPhaseTimeMs);
System.out.printf(" - Merge Phase : %d ms%n", mergePhaseTimeMs);
System.out.printf(" - Top-N Phase : %d ms%n", topNPhaseTimeMs);
System.out.println("MEMORY & DATA:");
System.out.printf(" - Peak Memory Used : %s%n", formatBytes(peakMemoryUsedBytes.get()));
System.out.printf(" - Unique Words Found: %,d%n", totalUniqueWords.get());
System.out.println("I/O STATISTICS:");
System.out.printf(" - Total Data Read : %s%n", formatBytes(totalBytesRead.get()));
System.out.printf(" - Data Spilled : %s (%d spill files)%n", formatBytes(totalBytesSpilled.get()), spillFileCount);
System.out.printf(" - Final Output Size : %s%n", formatBytes(finalFileBytes.get()));
if (mapPhaseTimeMs > 0) {
double throughput = (totalBytesRead.get() / (1024.0 * 1024.0)) / (mapPhaseTimeMs / 1000.0);
System.out.println("THROUGHPUT:");
System.out.printf(" - Map Phase : %.2f MB/s%n", throughput);
}
System.out.println("-------------------------------------------------");
}
void updatePeakMemory() {
long usedMemory = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
peakMemoryUsedBytes.getAndAccumulate(usedMemory, Math::max);
}
private static String formatBytes(long bytes) {
if (bytes < 1024) return bytes + " B";
int exp = (int) (Math.log(bytes) / Math.log(1024));
String pre = "KMGTPE".charAt(exp - 1) + "";
return String.format("%.2f %sB", bytes / Math.pow(1024, exp), pre);
}
}
public static void main(String[] args) throws IOException, InterruptedException {
if (args.length != 2) {
System.err.println("Usage: java LargeScaleTopWords <N> <root-directory-path>");
System.exit(1);
}
int topN = Integer.parseInt(args[0]);
Path rootDir = Paths.get(args[1]);
if (!Files.isDirectory(rootDir)) {
if (Files.isRegularFile(rootDir)) {
System.err.println("Error: Provided path is a file, but this implementation requires a directory path.");
System.err.println("Please specify a directory containing text files to process.");
System.err.println("If you want to process a single file, use VectorsWith1BigFile or WithMappedByteBuffer instead.");
System.err.println("Usage: java LargeScaleTopWords <N> <root-directory-path>");
} else {
System.err.println("Error: Provided path is not a valid directory: " + rootDir);
}
System.exit(1);
}
Path tempDir = Files.createTempDirectory("wordcount-spill");
Path finalOutputFile = tempDir.resolve("final-counts.txt");
PerformanceMetrics metrics = new PerformanceMetrics();
long totalStartTime = System.currentTimeMillis();
try {
// --- NEW PHASE: Combine all files into one large temp file ---
System.out.println("\n--- Pre-processing: Combining input files ---");
long preProcessingStartTime = System.currentTimeMillis();
Path combinedFile = tempDir.resolve("combined-input.txt");
combineFiles(rootDir, combinedFile, metrics);
metrics.preProcessingTimeMs = System.currentTimeMillis() - preProcessingStartTime;
System.out.printf("File combination complete in %d ms. Total size: %s%n", metrics.preProcessingTimeMs, PerformanceMetrics.formatBytes(metrics.totalBytesRead.get()));
// --- PHASE 1: MAP AND SPILL TO DISK ---
System.out.println("\n--- PHASE 1: Mapping and Spilling to Disk (1BRC Style) ---");
long mapStartTime = System.currentTimeMillis();
List<Path> tempFiles = mapAndSpill(combinedFile, tempDir, metrics);
metrics.mapPhaseTimeMs = System.currentTimeMillis() - mapStartTime;
metrics.spillFileCount = tempFiles.size();
System.out.printf("Map phase complete in %d ms. Created %d spill files.%n", metrics.mapPhaseTimeMs, metrics.spillFileCount);
// --- PHASE 2: MERGE AND REDUCE ---
System.out.println("\n--- PHASE 2: Merging Spill Files ---");
long mergeStartTime = System.currentTimeMillis();
mergeAndReduce(tempFiles, finalOutputFile, metrics);
metrics.mergePhaseTimeMs = System.currentTimeMillis() - mergeStartTime;
System.out.printf("Merge phase complete in %d ms. Final counts at: %s%n", metrics.mergePhaseTimeMs, finalOutputFile);
// --- PHASE 3: FIND TOP N ---
System.out.println("\n--- PHASE 3: Finding Top " + topN + " Words ---");
long topNStartTime = System.currentTimeMillis();
List<Map.Entry<String, Long>> topWords = findTopN(finalOutputFile, topN, metrics);
metrics.topNPhaseTimeMs = System.currentTimeMillis() - topNStartTime;
System.out.printf("Top-N phase complete in %d ms.%n", metrics.topNPhaseTimeMs);
// --- FINAL RESULTS ---
System.out.println("\n-------------------------------------------------");
System.out.println("Top " + topN + " most frequent words:");
for (Map.Entry<String, Long> entry : topWords) {
System.out.printf("%-20s : %d%n", entry.getKey(), entry.getValue());
}
System.out.println("-------------------------------------------------");
} finally {
try (Stream<Path> walk = Files.walk(tempDir)) {
walk.sorted(Comparator.reverseOrder()).forEach(path -> {
try { Files.delete(path); } catch (IOException e) { /* ignore */ }
});
}
System.out.println("\nCleaned up temporary directory: " + tempDir);
}
metrics.totalTimeMs = System.currentTimeMillis() - totalStartTime;
metrics.printReport();
}
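/**
* Concatenates every regular file under rootDir into a single temp file,
* inserting a newline between files so words never merge across file
* boundaries, and records the combined size as total bytes read.
*/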
private static void combineFiles(Path rootDir, Path combinedFile, PerformanceMetrics metrics) throws IOException {
try (OutputStream out = new BufferedOutputStream(Files.newOutputStream(combinedFile));
Stream<Path> paths = Files.walk(rootDir)) {
paths.filter(Files::isRegularFile).forEach(path -> {
try {
Files.copy(path, out);
// Add a newline to ensure words aren't merged across file boundaries
out.write('\n');
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});
}
metrics.totalBytesRead.set(Files.size(combinedFile));
}
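/**
* Memory-maps the combined input file and splits it into one chunk per
* available CPU core. Each chunk is parsed in its own task under a
* StructuredTaskScope, counted into a thread-local map, and spilled to
* its own sorted file on disk.
*/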
private static List<Path> mapAndSpill(Path largeFile, Path tempDir, PerformanceMetrics metrics) throws IOException, InterruptedException {
final List<Path> tempFiles = new CopyOnWriteArrayList<>();
final AtomicInteger spillFileCounter = new AtomicInteger(0);
int numThreads = Runtime.getRuntime().availableProcessors();
long fileSize = Files.size(largeFile);
long chunkSize = fileSize / numThreads;
// Use Arena for safe, automatic management of off-heap memory.
try (FileChannel fileChannel = FileChannel.open(largeFile, StandardOpenOption.READ);
// *** FIX: Use a shared Arena to allow access from multiple threads. ***
Arena arena = Arena.ofShared()) {
// Map the entire file into a MemorySegment using the FileChannel.
MemorySegment fileSegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, arena);
try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
for (int i = 0; i < numThreads; i++) {
long start = i * chunkSize;
long end = (i == numThreads - 1) ? fileSize : start + chunkSize;
if (start > 0) {
start = findWordBoundary(fileSegment, start);
}
final long finalStart = start;
final long finalEnd = end;
scope.fork(() -> {
Map<String, LongAdder> localMap = new HashMap<>(4096);
// Each thread works on a slice of the main memory segment.
MemorySegment chunkSegment = fileSegment.asSlice(finalStart, finalEnd - finalStart);
parseChunk(chunkSegment, localMap);
if (!localMap.isEmpty()) {
Path spillFile = tempDir.resolve("spill-" + spillFileCounter.getAndIncrement() + ".txt");
spillLocalMap(localMap, spillFile, metrics);
tempFiles.add(spillFile);
}
return null;
});
}
scope.join();
scope.throwIfFailed();
} catch (ExecutionException e) {
throw new IOException("A file processing task failed", e.getCause());
}
}
return tempFiles;
}
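/**
* Walks backwards from position (at most 256 bytes) to the start of the
* word containing it, so chunk boundaries can be aligned to word starts.
* Words longer than 256 bytes may still be split at the search limit.
*/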
private static long findWordBoundary(MemorySegment segment, long position) {
long searchPos = Math.max(0, position - 256);
for (long i = position - 1; i >= searchPos; i--) {
byte b = segment.get(ValueLayout.JAVA_BYTE, i);
if (b < 'a' || b > 'z') { // Simple check for non-lowercase-letter
return i + 1;
}
}
return searchPos;
}
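/**
* Scans a memory segment for words, using the Vector API to locate the
* next non-letter byte a whole SIMD register at a time, with a scalar
* loop for the tail. Only ASCII lowercase 'a'..'z' count as word
* characters; everything else (including uppercase and non-ASCII bytes)
* acts as a delimiter.
*/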
private static void parseChunk(MemorySegment segment, Map<String, LongAdder> localMap) {
long limit = segment.byteSize();
long position = 0;
while (position < limit) {
long wordStart = position;
// Use Vector API to quickly find the next non-letter character
long vectorLimit = limit - SPECIES.length();
while (position < vectorLimit) {
ByteVector vector = ByteVector.fromMemorySegment(SPECIES, segment, position, ByteOrder.nativeOrder());
var mask_lt_a = vector.compare(VectorOperators.LT, (byte) 'a');
var mask_gt_z = vector.compare(VectorOperators.GT, (byte) 'z');
var nonLetterMask = mask_lt_a.or(mask_gt_z);
if (nonLetterMask.anyTrue()) {
int firstNonLetter = nonLetterMask.firstTrue();
position += firstNonLetter;
break;
}
position += SPECIES.length();
}
// Scalar loop for the rest of the buffer
while (position < limit) {
byte b = segment.get(ValueLayout.JAVA_BYTE, position);
if (b >= 'a' && b <= 'z') {
position++;
} else {
break;
}
}
long wordLen = position - wordStart;
if (wordLen > 0) {
// Extract the word bytes from the segment and create a String.
String word = new String(segment.asSlice(wordStart, wordLen).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8);
localMap.computeIfAbsent(word, k -> new LongAdder()).increment();
}
// Skip over delimiters
while (position < limit) {
byte b = segment.get(ValueLayout.JAVA_BYTE, position);
if (b < 'a' || b > 'z') {
position++;
} else {
break;
}
}
}
}
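/**
* Writes one thread's local counts, sorted by word, to its own spill file.
* No synchronization is needed because each thread owns its file.
*/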
private static void spillLocalMap(Map<String, LongAdder> localMap, Path spillFile, PerformanceMetrics metrics) throws IOException {
metrics.updatePeakMemory();
List<Map.Entry<String, LongAdder>> sortedEntries = new ArrayList<>(localMap.entrySet());
sortedEntries.sort(Map.Entry.comparingByKey());
try (BufferedWriter writer = Files.newBufferedWriter(spillFile, StandardCharsets.UTF_8)) {
for (Map.Entry<String, LongAdder> entry : sortedEntries) {
writer.write(entry.getKey() + "\t" + entry.getValue().sum());
writer.newLine();
}
}
metrics.totalBytesSpilled.addAndGet(Files.size(spillFile));
}
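/**
* Performs a classic k-way merge over the sorted spill files using a
* priority queue keyed on the current word of each file, summing counts
* for equal words and streaming the totals to the final output file.
*/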
private static void mergeAndReduce(List<Path> tempFiles, Path finalOutputFile, PerformanceMetrics metrics) throws IOException {
List<BufferedReader> readers = new ArrayList<>();
PriorityQueue<WordFileEntry> pq = new PriorityQueue<>(Comparator.comparing(e -> e.word));
try {
for (Path file : tempFiles) {
BufferedReader reader = Files.newBufferedReader(file, StandardCharsets.UTF_8);
readers.add(reader);
String line = reader.readLine();
if (line != null) {
pq.add(new WordFileEntry(line, reader));
}
}
try (BufferedWriter writer = Files.newBufferedWriter(finalOutputFile, StandardCharsets.UTF_8)) {
String currentWord = null;
long currentCount = 0;
while (!pq.isEmpty()) {
WordFileEntry entry = pq.poll();
if (currentWord == null) currentWord = entry.word;
if (!entry.word.equals(currentWord)) {
writer.write(currentWord + "\t" + currentCount);
writer.newLine();
currentWord = entry.word;
currentCount = 0;
}
currentCount += entry.count;
String nextLine = entry.reader.readLine();
if (nextLine != null) {
pq.add(new WordFileEntry(nextLine, entry.reader));
}
}
if (currentWord != null) {
writer.write(currentWord + "\t" + currentCount);
writer.newLine();
}
}
metrics.finalFileBytes.set(Files.size(finalOutputFile));
} finally {
for (BufferedReader reader : readers) {
try { reader.close(); } catch (IOException e) { /* ignore */ }
}
}
}
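/**
* Streams the merged counts file once, keeping the N largest counts in a
* min-heap so memory use stays O(N) regardless of vocabulary size.
*/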
private static List<Map.Entry<String, Long>> findTopN(Path finalFile, int n, PerformanceMetrics metrics) throws IOException {
PriorityQueue<Map.Entry<String, Long>> topNHeap = new PriorityQueue<>(Map.Entry.comparingByValue());
final AtomicLong uniqueWordCounter = new AtomicLong(0);
try (Stream<String> lines = Files.lines(finalFile, StandardCharsets.UTF_8)) {
lines.forEach(line -> {
uniqueWordCounter.incrementAndGet();
String[] parts = line.split("\t");
if (parts.length == 2) {
String word = parts[0];
long count = Long.parseLong(parts[1]);
if (topNHeap.size() < n) {
topNHeap.add(new AbstractMap.SimpleEntry<>(word, count));
} else if (count > topNHeap.peek().getValue()) {
topNHeap.poll();
topNHeap.add(new AbstractMap.SimpleEntry<>(word, count));
}
}
});
}
metrics.totalUniqueWords.set(uniqueWordCounter.get());
List<Map.Entry<String, Long>> result = new ArrayList<>(topNHeap);
result.sort(Map.Entry.comparingByValue(Comparator.reverseOrder()));
return result;
}
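/**
* A cursor into one spill file: the current word and count, plus the
* reader positioned at the next line.
*/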
private static class WordFileEntry {
final String word;
final long count;
final BufferedReader reader;
WordFileEntry(String line, BufferedReader reader) {
String[] parts = line.split("\t");
this.word = parts[0];
this.count = Long.parseLong(parts[1]);
this.reader = reader;
}
}
}

LargeScaleTopWords.java
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Stream;
/**
* A command-line application to find the top N most frequent words from a large
* collection of files, with detailed performance metrics.
*
* This version uses a fully decentralized map phase where each thread spills
* its own local map to a unique file, eliminating all shared state.
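*
* Requires JDK 21+ with preview features enabled (StructuredTaskScope is
* a preview API there). A sketch of an invocation (exact flags depend on
* your JDK release):
*
*   javac --release 21 --enable-preview LargeScaleTopWords.java
*   java --enable-preview LargeScaleTopWords 10 /path/to/corpus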
*/
public class LargeScaleTopWords {
/**
* A flyweight, reusable object that represents a word as a slice of a char
* array. This avoids creating a new String object for every word lookup,
* drastically reducing object allocation and GC pressure. It implements
* CharSequence so it can be compared against the String keys in the map.
*/
private static final class WordView implements CharSequence {
private char[] buffer;
private int offset;
private int length;
private int hash;
public WordView set(char[] buffer, int offset, int length) {
this.buffer = buffer;
this.offset = offset;
this.length = length;
this.hash = 0; // Reset hash so it's recalculated
return this;
}
@Override
public int length() { return length; }
@Override
public char charAt(int index) {
if (index < 0 || index >= length) throw new StringIndexOutOfBoundsException(index);
return buffer[offset + index];
}
@Override
public CharSequence subSequence(int start, int end) {
throw new UnsupportedOperationException();
}
@Override
public int hashCode() {
int h = hash;
if (h == 0 && length > 0) {
for (int i = 0; i < length; i++) {
h = 31 * h + buffer[offset + i];
}
hash = h;
}
return h;
}
@Override
public boolean equals(Object anObject) {
if (this == anObject) return true;
if (anObject instanceof CharSequence) {
CharSequence other = (CharSequence) anObject;
if (length != other.length()) return false;
for (int i = 0; i < length; i++) {
if (buffer[offset + i] != other.charAt(i)) return false;
}
return true;
}
return false;
}
@Override
public String toString() {
return new String(buffer, offset, length);
}
}
/**
* Inner class to hold all performance-related metrics.
*/
private static class PerformanceMetrics {
long totalTimeMs;
long mapPhaseTimeMs;
long mergePhaseTimeMs;
long topNPhaseTimeMs;
final AtomicLong totalBytesRead = new AtomicLong(0);
final AtomicLong totalBytesSpilled = new AtomicLong(0);
final AtomicLong finalFileBytes = new AtomicLong(0);
final AtomicLong peakMemoryUsedBytes = new AtomicLong(0);
final AtomicLong totalUniqueWords = new AtomicLong(0);
long spillFileCount = 0;
public void printReport() {
System.out.println("\n--- PERFORMANCE METRICS ------------------------");
System.out.printf(" Total Execution Time : %d ms%n", totalTimeMs);
System.out.println("-------------------------------------------------");
System.out.println("PHASE TIMINGS:");
System.out.printf(" - Map & Spill Phase : %d ms%n", mapPhaseTimeMs);
System.out.printf(" - Merge Phase : %d ms%n", mergePhaseTimeMs);
System.out.printf(" - Top-N Phase : %d ms%n", topNPhaseTimeMs);
System.out.println("MEMORY & DATA:");
System.out.printf(" - Peak Memory Used : %s%n", formatBytes(peakMemoryUsedBytes.get()));
System.out.printf(" - Unique Words Found: %,d%n", totalUniqueWords.get());
System.out.println("I/O STATISTICS:");
System.out.printf(" - Total Data Read : %s%n", formatBytes(totalBytesRead.get()));
System.out.printf(" - Data Spilled : %s (%d spill files)%n", formatBytes(totalBytesSpilled.get()), spillFileCount);
System.out.printf(" - Final Output Size : %s%n", formatBytes(finalFileBytes.get()));
if (mapPhaseTimeMs > 0) {
double throughput = (totalBytesRead.get() / (1024.0 * 1024.0)) / (mapPhaseTimeMs / 1000.0);
System.out.println("THROUGHPUT:");
System.out.printf(" - Map Phase : %.2f MB/s%n", throughput);
}
System.out.println("-------------------------------------------------");
}
void updatePeakMemory() {
long usedMemory = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
peakMemoryUsedBytes.getAndAccumulate(usedMemory, Math::max);
}
private static String formatBytes(long bytes) {
if (bytes < 1024) return bytes + " B";
int exp = (int) (Math.log(bytes) / Math.log(1024));
String pre = "KMGTPE".charAt(exp - 1) + "";
return String.format("%.2f %sB", bytes / Math.pow(1024, exp), pre);
}
}
public static void main(String[] args) throws IOException, InterruptedException {
if (args.length != 2) {
System.err.println("Usage: java LargeScaleTopWords <N> <root-directory-path>");
System.exit(1);
}
int topN = Integer.parseInt(args[0]);
Path rootDir = Paths.get(args[1]);
if (!Files.isDirectory(rootDir)) {
if (Files.isRegularFile(rootDir)) {
System.err.println("Error: Provided path is a file, but this implementation requires a directory path.");
System.err.println("Please specify a directory containing text files to process.");
System.err.println("If you want to process a single file, use VectorsWith1BigFile or WithMappedByteBuffer instead.");
System.err.println("Usage: java LargeScaleTopWords <N> <root-directory-path>");
} else {
System.err.println("Error: Provided path is not a valid directory: " + rootDir);
}
System.exit(1);
}
Path tempDir = Files.createTempDirectory("wordcount-spill");
Path finalOutputFile = tempDir.resolve("final-counts.txt");
System.out.println("Starting large-scale word frequency count.");
System.out.println("Temporary spill files will be written to: " + tempDir);
PerformanceMetrics metrics = new PerformanceMetrics();
long totalStartTime = System.currentTimeMillis();
try {
// --- PHASE 1: MAP AND SPILL TO DISK ---
System.out.println("\n--- PHASE 1: Mapping and Spilling to Disk (Streaming) ---");
long mapStartTime = System.currentTimeMillis();
List<Path> tempFiles = mapAndSpill(rootDir, tempDir, metrics);
metrics.mapPhaseTimeMs = System.currentTimeMillis() - mapStartTime;
metrics.spillFileCount = tempFiles.size();
System.out.printf("Map phase complete in %d ms. Created %d spill files.%n", metrics.mapPhaseTimeMs, metrics.spillFileCount);
// --- PHASE 2: MERGE AND REDUCE ---
System.out.println("\n--- PHASE 2: Merging Spill Files ---");
long mergeStartTime = System.currentTimeMillis();
mergeAndReduce(tempFiles, finalOutputFile, metrics);
metrics.mergePhaseTimeMs = System.currentTimeMillis() - mergeStartTime;
System.out.printf("Merge phase complete in %d ms. Final counts at: %s%n", metrics.mergePhaseTimeMs, finalOutputFile);
// --- PHASE 3: FIND TOP N ---
System.out.println("\n--- PHASE 3: Finding Top " + topN + " Words ---");
long topNStartTime = System.currentTimeMillis();
List<Map.Entry<String, Long>> topWords = findTopN(finalOutputFile, topN, metrics);
metrics.topNPhaseTimeMs = System.currentTimeMillis() - topNStartTime;
System.out.printf("Top-N phase complete in %d ms.%n", metrics.topNPhaseTimeMs);
// --- FINAL RESULTS ---
System.out.println("\n-------------------------------------------------");
System.out.println("Top " + topN + " most frequent words:");
for (Map.Entry<String, Long> entry : topWords) {
System.out.printf("%-20s : %d%n", entry.getKey(), entry.getValue());
}
System.out.println("-------------------------------------------------");
} finally {
try (Stream<Path> walk = Files.walk(tempDir)) {
walk.sorted(Comparator.reverseOrder()).forEach(path -> {
try { Files.delete(path); } catch (IOException e) { /* ignore */ }
});
}
System.out.println("\nCleaned up temporary directory: " + tempDir);
}
metrics.totalTimeMs = System.currentTimeMillis() - totalStartTime;
metrics.printReport();
}
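/**
* Walks the directory tree and forks one task per regular file. Each task
* counts words into its own private map and spills it to a unique sorted
* file, so the map phase shares no mutable state between threads.
*/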
private static List<Path> mapAndSpill(Path rootDir, Path tempDir, PerformanceMetrics metrics) throws IOException, InterruptedException {
// The list of all generated spill files.
final List<Path> tempFiles = new CopyOnWriteArrayList<>();
// A simple counter to generate unique filenames for each spill file.
final AtomicInteger spillFileCounter = new AtomicInteger(0);
try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
try (Stream<Path> paths = Files.walk(rootDir)) {
paths.filter(Files::isRegularFile).forEach(filePath -> {
scope.fork(() -> {
// Each thread gets its own private HashMap, pre-sized for efficiency.
Map<String, LongAdder> localMap = new HashMap<>(4096);
final WordView wordView = new WordView();
try {
metrics.totalBytesRead.addAndGet(Files.size(filePath));
try (BufferedReader reader = Files.newBufferedReader(filePath, StandardCharsets.UTF_8)) {
String line;
while ((line = reader.readLine()) != null) {
parseLine(line, localMap, wordView);
}
}
// *** ARCHITECTURE: Each thread spills its own local map. ***
// If the map is not empty, write it to a new, unique spill file.
if (!localMap.isEmpty()) {
Path spillFile = tempDir.resolve("spill-" + spillFileCounter.getAndIncrement() + ".txt");
spillLocalMap(localMap, spillFile, metrics);
tempFiles.add(spillFile);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
return null;
});
});
}
scope.join();
scope.throwIfFailed();
} catch (ExecutionException e) {
throw new IOException("A file processing task failed", e.getCause());
}
// No global map means no final flush is needed here.
return tempFiles;
}
/**
* A high-performance manual parser that uses a flyweight WordView object
* to populate a thread-local map.
*/
private static void parseLine(String line, Map<String, LongAdder> localMap, WordView wordView) {
final char[] chars = line.toCharArray();
int wordStart = -1;
for (int i = 0; i < chars.length; i++) {
char c = chars[i];
if (Character.isLetter(c)) {
if (wordStart == -1) {
wordStart = i;
}
chars[i] = Character.toLowerCase(c);
} else {
if (wordStart != -1) {
processWord(localMap, wordView, chars, wordStart, i - wordStart);
wordStart = -1;
}
}
}
if (wordStart != -1) {
processWord(localMap, wordView, chars, wordStart, chars.length - wordStart);
}
}
/**
* Processes a found word slice for a non-concurrent, thread-local HashMap.
*/
private static void processWord(Map<String, LongAdder> localMap, WordView wordView, char[] buffer, int start, int len) {
wordView.set(buffer, start, len);
// Probe the map with the flyweight first: WordView's hashCode and equals
// are String-compatible, so a hit costs no allocation at all.
LongAdder adder = localMap.get(wordView);
if (adder == null) {
// Miss: materialize the word as a String exactly once, to serve as the
// permanent map key.
adder = new LongAdder();
localMap.put(wordView.toString(), adder);
}
adder.increment();
}
/**
* Writes the contents of a single thread's local map to a unique spill file.
* This method is NOT synchronized as each thread writes to its own file.
*/
private static void spillLocalMap(Map<String, LongAdder> localMap, Path spillFile, PerformanceMetrics metrics) throws IOException {
metrics.updatePeakMemory();
List<Map.Entry<String, LongAdder>> sortedEntries = new ArrayList<>(localMap.entrySet());
sortedEntries.sort(Map.Entry.comparingByKey());
try (BufferedWriter writer = Files.newBufferedWriter(spillFile, StandardCharsets.UTF_8)) {
for (Map.Entry<String, LongAdder> entry : sortedEntries) {
writer.write(entry.getKey() + "\t" + entry.getValue().sum());
writer.newLine();
}
}
metrics.totalBytesSpilled.addAndGet(Files.size(spillFile));
}
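/**
* Performs a classic k-way merge over the sorted spill files using a
* priority queue keyed on the current word of each file, summing counts
* for equal words and streaming the totals to the final output file.
*/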
private static void mergeAndReduce(List<Path> tempFiles, Path finalOutputFile, PerformanceMetrics metrics) throws IOException {
List<BufferedReader> readers = new ArrayList<>();
PriorityQueue<WordFileEntry> pq = new PriorityQueue<>(Comparator.comparing(e -> e.word));
try {
for (Path file : tempFiles) {
BufferedReader reader = Files.newBufferedReader(file, StandardCharsets.UTF_8);
readers.add(reader);
String line = reader.readLine();
if (line != null) {
pq.add(new WordFileEntry(line, reader));
}
}
try (BufferedWriter writer = Files.newBufferedWriter(finalOutputFile, StandardCharsets.UTF_8)) {
String currentWord = null;
long currentCount = 0;
while (!pq.isEmpty()) {
WordFileEntry entry = pq.poll();
if (currentWord == null) currentWord = entry.word;
if (!entry.word.equals(currentWord)) {
writer.write(currentWord + "\t" + currentCount);
writer.newLine();
currentWord = entry.word;
currentCount = 0;
}
currentCount += entry.count;
String nextLine = entry.reader.readLine();
if (nextLine != null) {
pq.add(new WordFileEntry(nextLine, entry.reader));
}
}
if (currentWord != null) {
writer.write(currentWord + "\t" + currentCount);
writer.newLine();
}
}
metrics.finalFileBytes.set(Files.size(finalOutputFile));
} finally {
for (BufferedReader reader : readers) {
try { reader.close(); } catch (IOException e) { /* ignore */ }
}
}
}
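/**
* Streams the merged counts file once, keeping the N largest counts in a
* min-heap so memory use stays O(N) regardless of vocabulary size.
*/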
private static List<Map.Entry<String, Long>> findTopN(Path finalFile, int n, PerformanceMetrics metrics) throws IOException {
PriorityQueue<Map.Entry<String, Long>> topNHeap = new PriorityQueue<>(Map.Entry.comparingByValue());
final AtomicLong uniqueWordCounter = new AtomicLong(0);
try (Stream<String> lines = Files.lines(finalFile, StandardCharsets.UTF_8)) {
lines.forEach(line -> {
uniqueWordCounter.incrementAndGet();
String[] parts = line.split("\t");
if (parts.length == 2) {
String word = parts[0];
long count = Long.parseLong(parts[1]);
if (topNHeap.size() < n) {
topNHeap.add(new AbstractMap.SimpleEntry<>(word, count));
} else if (count > topNHeap.peek().getValue()) {
topNHeap.poll();
topNHeap.add(new AbstractMap.SimpleEntry<>(word, count));
}
}
});
}
metrics.totalUniqueWords.set(uniqueWordCounter.get());
List<Map.Entry<String, Long>> result = new ArrayList<>(topNHeap);
result.sort(Map.Entry.comparingByValue(Comparator.reverseOrder()));
return result;
}
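/**
* A cursor into one spill file: the current word and count, plus the
* reader positioned at the next line.
*/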
private static class WordFileEntry {
final String word;
final long count;
final BufferedReader reader;
WordFileEntry(String line, BufferedReader reader) {
String[] parts = line.split("\t");
this.word = parts[0];
this.count = Long.parseLong(parts[1]);
this.reader = reader;
}
}
}