121 lines
4.4 KiB
TypeScript
121 lines
4.4 KiB
TypeScript
import { publicProcedure, router } from "../trpc";
|
|
import { TRPCError } from "@trpc/server";
|
|
import { db } from '~/server/db'
|
|
import { chatMessage,
|
|
chatSession, systemSettings } from "../dbschema/schema";
|
|
import { isAdmin } from '~/app/actions';
|
|
import { z } from 'zod';
|
|
import { eq } from 'drizzle-orm';
|
|
import { clerkClient, auth } from '@clerk/nextjs/server'
|
|
import { env } from '~/env'
|
|
|
|
/** Model id that `getModel` falls back to when the stored settings carry no model. */
export const DEFAULT_MODEL = 'gpt-5-mini'
|
|
|
|
// Matches ids of models returned by the OpenAI API that aren't usable for chat
// completions. `listModels` drops every model id that this pattern matches.
|
|
const NON_CHAT_MODEL = /embedding|image|audio|realtime|transcribe|tts|whisper|moderation|dall-e|search|codex|instruct/
|
|
|
|
/**
 * Loads the stored system settings.
 *
 * The query is limited to one row and only that first row is handed back, so
 * the result is `undefined` when the settings table is empty.
 */
async function readSettings() {
  const [firstRow] = await db.select().from(systemSettings).limit(1)
  return firstRow
}
|
|
|
|
/**
 * Persists the system settings, merging `values` over whatever is already stored.
 *
 * A field that is omitted (or `null`) in `values` keeps its stored value; when
 * nothing is stored either, it is written as `null`.
 *
 * An existing settings row is updated in place with a single statement instead
 * of being deleted and re-inserted, so a failure while writing can no longer
 * leave the table empty. A row is inserted only when none exists yet.
 *
 * @param values Fields to overwrite. `systemPropmt` keeps the spelling used by
 *   the schema field referenced throughout this file.
 */
async function writeSettings(values: { systemPropmt?: string | null; model?: string | null }) {
  const current = await readSettings()
  const next = {
    systemPropmt: values.systemPropmt ?? current?.systemPropmt ?? null,
    model: values.model ?? current?.model ?? null,
  }

  if (current === undefined) {
    await db.insert(systemSettings).values(next)
    return
  }

  // No `where` clause on purpose: the table is treated as holding a single
  // settings row (reads take the first row only), so every row is brought in sync.
  // NOTE(review): columns other than these two are left untouched by the update.
  await db.update(systemSettings).set(next)
}
|
|
|
|
export const chatRouter = router({
|
|
/**
 * Returns the signed-in user's chat session, creating one when none exists.
 *
 * @returns The existing session row, or the freshly inserted one.
 * @throws TRPCError `UNAUTHORIZED` when no user is signed in.
 * @throws TRPCError `INTERNAL_SERVER_ERROR` when the insert yields no row.
 */
getSession: publicProcedure.query(async () => {
  const { userId } = await auth();
  if (!userId) {
    throw new TRPCError({ message: "chat is only available to signed in users", code: 'UNAUTHORIZED' });
  }

  // The Clerk lookup is kept from the original flow; sessions are keyed by the
  // id of the user record it returns.
  const clerk = await clerkClient();
  const user = await clerk.users.getUser(userId);

  const existing = await db.query.chatSession.findFirst({
    where(fields, operators) {
      return operators.eq(fields.userId, user.id);
    },
  });
  if (existing !== undefined) {
    return existing;
  }

  // `returning()` already hands back the inserted row, so it is returned
  // directly rather than re-querying the table for it.
  const inserted = await db.insert(chatSession).values({ userId: user.id }).returning();
  const created = inserted.at(0);
  if (created === undefined) {
    throw new TRPCError({ message: "failed to create session", code: "INTERNAL_SERVER_ERROR" });
  }
  return created;
}),
|
|
/**
 * Lists the messages of a chat session owned by the signed-in user.
 *
 * The session id comes from the client, so it is checked against the caller's
 * own sessions before any message is read; otherwise any caller could read any
 * session's messages just by knowing its id.
 *
 * @param input Id of the session whose messages are requested.
 * @returns The session's messages, or an empty list when the session does not
 *   exist or belongs to a different user.
 * @throws TRPCError `UNAUTHORIZED` when no user is signed in.
 */
getMessages: publicProcedure.input(z.string()).query(async ({ input }) => {
  const { userId } = await auth();
  if (!userId) {
    throw new TRPCError({ message: "chat is only available to signed in users", code: 'UNAUTHORIZED' });
  }

  // NOTE(review): if admins are meant to read other users' transcripts through
  // this procedure, add an explicit `isAdmin()` bypass here.
  const ownedSession = await db.query.chatSession.findFirst({
    where(fields, operators) {
      return operators.and(
        operators.eq(fields.id, input),
        operators.eq(fields.userId, userId),
      );
    },
  });
  if (ownedSession === undefined) {
    // Same result an unknown session id has always produced; reveals nothing
    // about whether the id exists for someone else.
    return [];
  }

  return db.query.chatMessage.findMany({
    where(fields, operators) {
      return operators.eq(fields.sessionId, input);
    },
  });
}),
|
|
/**
 * Deletes every message in the signed-in user's chat session. The session row
 * itself is kept.
 *
 * Does nothing when the user has no session.
 *
 * @throws TRPCError `UNAUTHORIZED` when no user is signed in.
 */
clearChat: publicProcedure.mutation(async () => {
  console.log("deleting session")
  const { userId } = await auth();
  if (userId == null) {
    throw new TRPCError({ message: "chat is only available to signed in users", code: 'UNAUTHORIZED' });
  }

  // Only the session id is needed, so the messages relation is not loaded.
  const session = await db.query.chatSession.findFirst({
    where(fields, operators) {
      return operators.eq(fields.userId, userId);
    },
  });
  if (session === undefined) {
    return;
  }

  // Awaited so the mutation resolves only after the rows are gone and a failed
  // delete reaches the caller instead of becoming an unhandled rejection.
  await db.delete(chatMessage).where(eq(chatMessage.sessionId, session.id));
}),
|
|
/**
 * Returns the stored system prompt, or an empty string when no settings row
 * exists or its prompt is unset.
 */
getSystemPrompt: publicProcedure.query(async () => {
  const settings = await readSettings()
  const storedPrompt = settings?.systemPropmt
  return storedPrompt ?? ''
}),
|
|
/**
 * Stores a new system prompt. Restricted to admins.
 *
 * @throws TRPCError `FORBIDDEN` when `isAdmin()` resolves to a falsy value.
 */
updateSystemPrompt: publicProcedure
  .input(z.object({ prompt: z.string() }))
  .mutation(async ({ input }) => {
    const allowed = await isAdmin()
    if (!allowed) {
      throw new TRPCError({ code: 'FORBIDDEN' })
    }
    const { prompt } = input
    await writeSettings({ systemPropmt: prompt })
  }),
|
|
/**
 * Returns the stored model id, falling back to `DEFAULT_MODEL` when no settings
 * row exists or its model is unset.
 */
getModel: publicProcedure.query(async () => {
  const settings = await readSettings()
  const storedModel = settings?.model
  return storedModel ?? DEFAULT_MODEL
}),
|
|
/**
 * Lists the ids of chat-capable models offered by the OpenAI models endpoint.
 * Restricted to admins.
 *
 * An id is kept when it starts with `gpt`, `chatgpt`, or `o` followed by a
 * digit, and does not match `NON_CHAT_MODEL`.
 *
 * @returns The matching model ids, sorted.
 * @throws TRPCError `FORBIDDEN` when `isAdmin()` resolves to a falsy value.
 * @throws TRPCError `INTERNAL_SERVER_ERROR` when the endpoint answers with a
 *   non-OK status or a body that is not `{ data: { id: string }[] }`.
 */
listModels: publicProcedure.query(async () => {
  if (!(await isAdmin())) throw new TRPCError({ code: 'FORBIDDEN' })

  const res = await fetch('https://api.openai.com/v1/models', {
    headers: { Authorization: `Bearer ${env.OPENAI_API_KEY}` },
  })
  if (!res.ok) {
    throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `failed to fetch models (${res.status})` })
  }

  // The body comes from an external service: validate its shape rather than
  // asserting it, so a surprise payload fails with a clear error.
  const payload: unknown = await res.json()
  const parsed = z.object({ data: z.array(z.object({ id: z.string() })) }).safeParse(payload)
  if (!parsed.success) {
    throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: 'unexpected response shape from models endpoint' })
  }

  return parsed.data.data
    .map((m) => m.id)
    .filter((id) => (id.startsWith('gpt') || /^o\d/.test(id) || id.startsWith('chatgpt')) && !NON_CHAT_MODEL.test(id))
    .sort()
}),
|
|
/**
 * Stores the model id to use. Restricted to admins.
 *
 * The id is trimmed and must be non-empty. An empty string would be persisted
 * as-is and, because `getModel` only falls back on null/undefined, would be
 * served instead of `DEFAULT_MODEL`.
 *
 * @throws TRPCError `FORBIDDEN` when `isAdmin()` resolves to a falsy value.
 */
updateModel: publicProcedure.input(z.object({ model: z.string().trim().min(1) })).mutation(async ({ input }) => {
  if (!(await isAdmin())) throw new TRPCError({ code: 'FORBIDDEN' })
  await writeSettings({ model: input.model })
}),
|
|
})
|
|
|
|
/** Static type of `chatRouter`, exported so other modules can refer to the router's shape without importing its value. */
export type ChatRouter = typeof chatRouter;
|