Skip to content

Instantly share code, notes, and snippets.

@eric-humane
Last active October 16, 2024 23:11
Show Gist options
  • Select an option

  • Save eric-humane/bc4a8853fe29a95305b5f22feec3d74c to your computer and use it in GitHub Desktop.

Select an option

Save eric-humane/bc4a8853fe29a95305b5f22feec3d74c to your computer and use it in GitHub Desktop.
// Example: an English-speaking entry agent that hands Spanish speakers over to
// a Spanish-only agent via a transfer tool.
const swarm = new Swarm();

// Target of the hand-off; it has no tools of its own.
const spanishAgent = new Agent({
  name: "Spanish Agent",
  instructions: "You only speak Spanish.",
});

// Entry point of the conversation; may call `transfer_to_spanish_agent`.
const englishAgent = new Agent({
  name: "English Agent",
  instructions: "You only speak English.",
  tools: [
    Tool.TransferToAgent(spanishAgent, "Transfer spanish speaking users immediately."),
  ],
});

try {
  // Start with the English agent and an empty context-variable bag.
  const result = await swarm.run(
    englishAgent,
    [{ role: "user", content: "Hola. ¿Como estás?" }],
    {},
  );
  console.log(result);
} catch (error) {
  console.log(error);
}
{
"name": "swarm",
"version": "0.0.1",
"private": true,
"dependencies": {
"openai": "^4.67.3",
"zod": "^3.23.8",
"zod-to-openai-tool": "^0.13.1"
}
}
import { t, createTools } from "zod-to-openai-tool";
import { z } from "zod";
import {
ChatCompletionCreateParams,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
} from "openai/resources";
import { OpenAI } from "openai";
interface DefaultContext {
[key: string]: unknown;
}
/**
 * A function the model may call: a name, a description shown to the model,
 * an optional zod schema describing its JSON arguments, and the implementation.
 */
export class Tool {
  name: string;
  description: string;
  input?: z.AnyZodObject;
  func: AgentFunction;

  constructor(params: {
    name: string;
    description: string;
    input?: z.AnyZodObject;
    func: AgentFunction;
  }) {
    this.name = params.name;
    this.description = params.description;
    this.input = params.input;
    this.func = params.func;
  }

  /**
   * Creates an argument-less tool that returns `agent` when called, which the
   * Swarm loop treats as a hand-off to that agent.
   *
   * The tool name is `transfer_to_<agent name>`, lower-cased, with every
   * character outside `[a-z0-9_-]` replaced by `_` and truncated to 64
   * characters, so it is always a valid OpenAI function name.
   *
   * @param agent Agent to transfer the conversation to.
   * @param description Description shown to the model. Defaults to
   *   "Transfer the conversation to <agent name>." so the tool is never sent
   *   without one.
   */
  static TransferToAgent = (agent: Agent, description?: string) => {
    return new Tool({
      name: `transfer_to_${agent.name.toLowerCase()}`
        .replace(/[^a-z0-9_-]/g, "_")
        .slice(0, 64),
      description: description ?? `Transfer the conversation to ${agent.name}.`,
      func: () => agent,
    });
  };
}
/**
 * An LLM persona: the model to call, its system instructions (either a fixed
 * string or a function of the context variables), and the tools it may use.
 */
export class Agent {
  static DEFAULT_MODEL = "gpt-4o";
  static DEFAULT_INSTRUCTIONS = "You are a helpful assistant.";

  name: string;
  model: string;
  instructions: string | ((input: Record<string, any>) => string);
  tools: Tool[];
  /** Name of a tool the model is forced to call, if set. */
  toolChoice?: string;
  parallelToolCalls: boolean;

  /**
   * @param options.name Display name; also used to derive transfer-tool names.
   * @param options.model Model id; defaults to `Agent.DEFAULT_MODEL`.
   * @param options.instructions System prompt or prompt factory; defaults to
   *   `Agent.DEFAULT_INSTRUCTIONS`.
   * @param options.tools Tools exposed to the model; defaults to none.
   * @param options.toolChoice Optional tool name to force.
   * @param options.parallelToolCalls Whether parallel tool calls are allowed;
   *   defaults to `true`.
   */
  constructor(options: {
    name: string;
    model?: string;
    instructions?: string | ((input: Record<string, any>) => string);
    tools?: Tool[];
    toolChoice?: string;
    parallelToolCalls?: boolean;
  }) {
    const {
      name,
      model = Agent.DEFAULT_MODEL,
      instructions = Agent.DEFAULT_INSTRUCTIONS,
      tools = [],
      toolChoice,
      parallelToolCalls = true,
    } = options;
    this.name = name;
    this.model = model;
    this.instructions = instructions;
    this.tools = tools;
    this.toolChoice = toolChoice;
    this.parallelToolCalls = parallelToolCalls;
  }

  /**
   * Resolves the system prompt. A string is returned unchanged; a function is
   * invoked (as a method of this agent) with `contextVariables`.
   */
  getInstructions(contextVariables: Record<string, any>): string {
    if (typeof this.instructions !== "function") {
      return this.instructions;
    }
    return this.instructions(contextVariables);
  }
}
/**
 * Rich return value for a tool implementation: the text sent back to the
 * model, an optional agent to hand the conversation to, and context-variable
 * updates to merge into the run's context.
 */
export class Result<TContext extends DefaultContext = DefaultContext> {
  value: string;
  agent?: Agent;
  contextVariables: TContext;

  constructor(params: { value: string; agent?: Agent; contextVariables: TContext }) {
    const { value, agent, contextVariables } = params;
    this.value = value;
    this.agent = agent;
    this.contextVariables = contextVariables;
  }
}
/**
 * Accumulated outcome of a run (or of one batch of tool calls): the messages
 * produced, the active agent, and the resulting context variables.
 */
export class Response<TContext extends DefaultContext = DefaultContext> {
  messages: ChatCompletionMessageParam[];
  agent?: Agent;
  contextVariables: TContext;

  constructor(params: {
    messages: ChatCompletionMessageParam[];
    agent?: Agent;
    contextVariables: TContext;
  }) {
    const { messages, agent, contextVariables } = params;
    this.messages = messages;
    this.agent = agent;
    this.contextVariables = contextVariables;
  }

  /**
   * Appends `message` to `messages` in place. The array given to the
   * constructor is stored as-is, so the caller's array is mutated too.
   */
  addMessage(message: ChatCompletionMessageParam) {
    this.messages.push(message);
  }

  /** Replaces `contextVariables` with a shallow merge in which `newVars` win. */
  updateContextVariables(newVars: Partial<TContext>) {
    const merged = { ...this.contextVariables, ...newVars };
    this.contextVariables = merged;
  }

  /** Records `agent` as the active agent. */
  setAgent(agent: Agent) {
    this.agent = agent;
  }
}
/**
 * Signature of a tool implementation. It receives the parsed tool-call
 * arguments and the current context variables, and may return its value
 * synchronously or as a promise.
 *
 * The return type parameter is named `TReturn` so it does not shadow the
 * global `ReturnType` utility type. Its default collapses to `unknown`; the
 * `string | Agent` members only document the values the Swarm loop
 * special-cases.
 */
type AgentFunction<
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- tools declare their own argument shapes
  Args = any,
  TContext extends DefaultContext = DefaultContext,
  TReturn = string | Agent | unknown,
> = (args: Args, ctx: TContext) => TReturn | Promise<TReturn>;
/**
 * Drives a conversation loop over OpenAI chat completions. Tool calls are
 * executed locally, and a tool may hand the conversation to another agent by
 * returning it.
 */
export class Swarm<TContext extends DefaultContext = DefaultContext> {
  client: OpenAI;

  /**
   * @param client OpenAI client used for every completion. Defaults to one
   *   built from the `OPENAI_API_KEY` environment variable.
   */
  constructor(client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })) {
    this.client = client;
  }

  /**
   * Requests one non-streaming chat completion on behalf of `agent`.
   *
   * The agent's instructions, resolved against `contextVariables`, are sent as
   * a system message ahead of `history`. The agent's tools are converted to
   * OpenAI function tools with zod-to-openai-tool.
   */
  private getChatCompletion(
    agent: Agent,
    history: ChatCompletionMessageParam[],
    contextVariables: TContext,
    modelOverride?: string
  ) {
    console.log("Getting completion from agent: " + agent.name);
    const instructions = agent.getInstructions(contextVariables);
    const messages: ChatCompletionMessageParam[] = [
      { role: "system", content: instructions },
      ...history,
    ];
    // Only the generated tool definitions are used here. The no-op `run`
    // callback is never invoked because tools execute in handleToolCalls.
    const tools = agent.tools.map(({ name, description, input }) => {
      const {
        tools: [tool],
      } = createTools({
        [name]: t
          .input(input ?? z.object({}))
          .describe(description)
          .run(() => {}),
      });
      return tool;
    });
    const createParams: ChatCompletionCreateParams = {
      model: modelOverride ?? agent.model,
      messages,
      stream: false,
      // `parallel_tool_calls` is only sent together with a non-empty `tools` list.
      ...(tools.length > 0 && {
        tools,
        parallel_tool_calls: agent.parallelToolCalls,
      }),
      ...(agent.toolChoice && {
        tool_choice: { type: "function", function: { name: agent.toolChoice } },
      }),
    };
    return this.client.chat.completions.create(createParams);
  }

  /**
   * Executes the model's tool calls concurrently and folds their outcomes into
   * a partial Response.
   *
   * Every call yields exactly one `tool` message, including on failure, so the
   * model gets a reply for each `tool_call_id`. Context updates are merged in
   * call order. If several tools return an agent, the last one in call order
   * wins.
   *
   * @param toolCalls Tool calls from the assistant message.
   * @param tools Tools available to the active agent.
   * @param contextVariables Current context, passed to each tool.
   */
  private async handleToolCalls(
    toolCalls: ChatCompletionMessageToolCall[],
    tools: Tool[],
    contextVariables: TContext
  ): Promise<Response<TContext>> {
    type ToolCallOutcome = {
      message: ChatCompletionMessageParam;
      contextVariables?: TContext;
      agent?: Agent;
    };

    const partialResponse = new Response<TContext>({
      messages: [],
      contextVariables: { ...contextVariables },
    });
    const toolsByName = new Map(tools.map((tool) => [tool.name, tool]));

    const outcomes = await Promise.all(
      toolCalls.map(async (toolCall): Promise<ToolCallOutcome> => {
        const tool_call_id = toolCall.id;
        const name = toolCall.function.name;
        const tool = toolsByName.get(name);
        const msg = (content: string): ChatCompletionMessageParam => ({
          role: "tool",
          tool_call_id,
          content,
        });

        if (!tool) {
          console.error(`Tool ${name} not found in function map.`);
          return { message: msg(`Error: Tool ${name} not found.`) };
        }

        let args: unknown;
        try {
          // Treat an empty arguments string as "no arguments", not a parse error.
          const raw = toolCall.function.arguments.trim();
          args = raw === "" ? {} : JSON.parse(raw);
        } catch (error: unknown) {
          console.error(`Failed to parse arguments for tool ${name}:`, error);
          return {
            message: msg(
              `Error: Invalid arguments for tool ${name}. Details: ${Swarm.describeError(error)}`
            ),
          };
        }

        // Arguments are model-generated, so validate them against the tool's
        // schema before invoking it.
        if (tool.input) {
          const parsed = tool.input.safeParse(args);
          if (!parsed.success) {
            console.error(`Invalid arguments for tool ${name}:`, parsed.error);
            return {
              message: msg(
                `Error: Invalid arguments for tool ${name}. Details: ${parsed.error.message}`
              ),
            };
          }
          args = parsed.data;
        }

        console.log(`Processing tool call: ${name} with arguments`, args);
        try {
          const result = this.handleResult(await tool.func(args, contextVariables));
          return {
            message: msg(result.value),
            contextVariables: result.contextVariables,
            agent: result.agent,
          };
        } catch (error: unknown) {
          console.error(`Error executing tool ${name}:`, error);
          return {
            message: msg(
              `Error executing tool ${name}. Details: ${Swarm.describeError(error)}`
            ),
          };
        }
      })
    );

    // Combine in call order only after every tool has finished.
    for (const outcome of outcomes) {
      partialResponse.addMessage(outcome.message);
      if (outcome.contextVariables) {
        partialResponse.updateContextVariables(outcome.contextVariables);
      }
      if (outcome.agent) {
        partialResponse.setAgent(outcome.agent);
      }
    }
    return partialResponse;
  }

  /**
   * Normalizes a tool's return value into a Result:
   * - A Result is passed through.
   * - An Agent becomes a hand-off whose value is `{"assistant": <name>}`.
   * - Anything else becomes the tool-message text: strings verbatim,
   *   `undefined` as "", other values as JSON.
   */
  private handleResult(result: unknown): Result<TContext> {
    if (result instanceof Result) {
      // The tool author chose the context type; it cannot be checked at runtime.
      return result as Result<TContext>;
    }
    if (result instanceof Agent) {
      return new Result<TContext>({
        value: JSON.stringify({ assistant: result.name }),
        agent: result,
        contextVariables: {} as TContext,
      });
    }
    return new Result<TContext>({
      value: Swarm.stringifyValue(result),
      contextVariables: {} as TContext,
    });
  }

  /**
   * Converts an arbitrary tool return value to tool-message text. Falls back
   * to `String()` for values JSON cannot represent (functions, bigint,
   * circular structures).
   */
  private static stringifyValue(value: unknown): string {
    if (typeof value === "string") {
      return value;
    }
    if (value === undefined) {
      return "";
    }
    try {
      return JSON.stringify(value) ?? String(value);
    } catch {
      return String(value);
    }
  }

  /** Extracts a readable message from a caught value of unknown type. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /**
   * Returns the slice of `history` sent to the model.
   *
   * If `history` has at most `limit` messages it is returned as-is. Otherwise
   * the result is the first message plus the most recent `limit - 1` messages.
   * Any leading `tool` messages in that tail are dropped, because a `tool`
   * message is only valid right after the assistant message that issued its
   * tool call. `history` itself is never modified.
   */
  private static windowHistory(
    history: ChatCompletionMessageParam[],
    limit: number
  ): ChatCompletionMessageParam[] {
    if (history.length <= limit) {
      return history;
    }
    // `slice(-0)` would return the whole array, so limit === 1 needs an empty tail.
    const tail: ChatCompletionMessageParam[] =
      limit > 1 ? history.slice(-(limit - 1)) : [];
    let start = 0;
    while (start < tail.length && tail[start]?.role === "tool") {
      start++;
    }
    return [...history.slice(0, 1), ...tail.slice(start)];
  }

  /**
   * Runs the conversation loop, starting with `agent`.
   *
   * Each iteration requests a completion. If the reply has tool calls (and
   * `executeTools` is true), the tools run, their messages are appended, and
   * context updates and agent hand-offs are applied; otherwise the loop ends.
   *
   * @param agent Agent handling the first turn.
   * @param messages Prior conversation. The array is not mutated.
   * @param contextVariables Initial context. A copy is used, so the caller's
   *   object is not mutated.
   * @param maxTurns No further completion is requested once this run has
   *   produced at least this many messages (assistant and tool messages both
   *   count).
   * @param executeTools When false, stop after the first completion even if it
   *   contains tool calls.
   * @param modelOverride Model used instead of each agent's own model.
   * @param historyLimit Maximum number of history messages sent per request.
   *   Older messages are omitted from the request but kept in the transcript.
   * @returns The messages produced by this run, the final active agent, and
   *   the final context variables including tool updates.
   * @throws Error if an argument is invalid or a completion has no choices.
   */
  public async run(
    agent: Agent,
    messages: ChatCompletionMessageParam[],
    contextVariables: TContext,
    maxTurns: number = 25,
    executeTools: boolean = true,
    modelOverride?: string,
    historyLimit: number = 100
  ): Promise<Response<TContext>> {
    if (!Number.isInteger(maxTurns) || maxTurns <= 0) {
      throw new Error("maxTurns must be a positive integer");
    }
    if (!Number.isInteger(historyLimit) || historyLimit <= 0) {
      throw new Error("historyLimit must be a positive integer");
    }
    if (typeof contextVariables !== "object" || contextVariables === null) {
      throw new Error("contextVariables must be a non-null object");
    }

    let activeAgent = agent;
    let context = { ...contextVariables };
    // Full transcript. Truncation only affects what is sent, never this array.
    const history: ChatCompletionMessageParam[] = [...messages];
    // Messages produced by this run: counted against maxTurns and returned.
    const newMessages: ChatCompletionMessageParam[] = [];

    while (newMessages.length < maxTurns) {
      const completion = await this.getChatCompletion(
        activeAgent,
        Swarm.windowHistory(history, historyLimit),
        context,
        modelOverride
      );
      const message = completion.choices[0]?.message;
      if (!message) {
        throw new Error("OpenAI returned a completion with no choices");
      }
      console.log("Received completion:", message);
      history.push(message);
      newMessages.push(message);

      // An empty tool_calls array would repeat the same request forever.
      if (!message.tool_calls?.length || !executeTools) {
        console.log("Ending turn.");
        break;
      }

      const partialResponse = await this.handleToolCalls(
        message.tool_calls,
        activeAgent.tools,
        context
      );
      history.push(...partialResponse.messages);
      newMessages.push(...partialResponse.messages);
      // Tool-provided updates win over the current context.
      context = { ...context, ...partialResponse.contextVariables };
      if (partialResponse.agent) {
        activeAgent = partialResponse.agent;
      }
    }

    return new Response({
      messages: newMessages,
      agent: activeAgent,
      contextVariables: context,
    });
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment