Compare commits
1 Commits
b59fb2b3af
...
ai-model-s
| Author | SHA1 | Date | |
|---|---|---|---|
| 54f108ac8d |
48
src/app/admin/chat/_components/ModelSelector.tsx
Normal file
48
src/app/admin/chat/_components/ModelSelector.tsx
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
'use client'
|
||||||
|
import { trpc } from '~/app/_trpc/Client'
|
||||||
|
import {
|
||||||
|
Select,
|
||||||
|
SelectContent,
|
||||||
|
SelectItem,
|
||||||
|
SelectTrigger,
|
||||||
|
SelectValue,
|
||||||
|
} from '~/components/ui/select'
|
||||||
|
|
||||||
|
export default function ModelSelector({ initialValue }: { initialValue: string }) {
|
||||||
|
const utils = trpc.useUtils()
|
||||||
|
const { data: models, isLoading, error } = trpc.chat.listModels.useQuery()
|
||||||
|
const { data: model = initialValue } = trpc.chat.getModel.useQuery(undefined, {
|
||||||
|
initialData: initialValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
const mutation = trpc.chat.updateModel.useMutation({
|
||||||
|
onSuccess: () => utils.chat.getModel.invalidate(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Ensure the currently-saved model is always selectable, even if the
|
||||||
|
// OpenAI list doesn't include it (e.g. a deprecated model).
|
||||||
|
const options = Array.from(new Set([model, ...(models ?? [])])).filter(Boolean)
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
|
<Select value={model} onValueChange={(v) => mutation.mutate({ model: v })}>
|
||||||
|
<SelectTrigger className="w-72">
|
||||||
|
<SelectValue placeholder={isLoading ? 'Loading models…' : 'Select a model'} />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
{options.map((id) => (
|
||||||
|
<SelectItem key={id} value={id}>
|
||||||
|
{id}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
<div className="flex items-center gap-3 text-sm text-muted-foreground">
|
||||||
|
{mutation.isPending && <span>Saving…</span>}
|
||||||
|
{mutation.isSuccess && !mutation.isPending && <span>Saved</span>}
|
||||||
|
{error && <span className="text-destructive">Failed to load models: {error.message}</span>}
|
||||||
|
{mutation.error && <span className="text-destructive">{mutation.error.message}</span>}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,18 +1,31 @@
|
|||||||
import { servTrpc } from '~/app/_trpc/ServerClient'
|
import { servTrpc } from '~/app/_trpc/ServerClient'
|
||||||
import SystemPromptForm from './_components/SystemPromptForm'
|
import SystemPromptForm from './_components/SystemPromptForm'
|
||||||
|
import ModelSelector from './_components/ModelSelector'
|
||||||
|
|
||||||
export default async function SystemPromptPage() {
|
export default async function SystemPromptPage() {
|
||||||
const prompt = await servTrpc.chat.getSystemPrompt()
|
const prompt = await servTrpc.chat.getSystemPrompt()
|
||||||
|
const model = await servTrpc.chat.getModel()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full max-w-2xl p-6 flex flex-col gap-4">
|
<div className="w-full max-w-2xl p-6 flex flex-col gap-8">
|
||||||
<div>
|
<div className="flex flex-col gap-4">
|
||||||
<h1 className="text-lg font-semibold">AI System Prompt</h1>
|
<div>
|
||||||
<p className="text-sm text-muted-foreground">
|
<h1 className="text-lg font-semibold">AI Model</h1>
|
||||||
This prompt is sent to the model on every chat request.
|
<p className="text-sm text-muted-foreground">
|
||||||
</p>
|
The OpenAI model used to respond to chat requests.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<ModelSelector initialValue={model} />
|
||||||
|
</div>
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<div>
|
||||||
|
<h1 className="text-lg font-semibold">AI System Prompt</h1>
|
||||||
|
<p className="text-sm text-muted-foreground">
|
||||||
|
This prompt is sent to the model on every chat request.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<SystemPromptForm initialValue={prompt} />
|
||||||
</div>
|
</div>
|
||||||
<SystemPromptForm initialValue={prompt} />
|
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ export async function POST(req: Request) {
|
|||||||
if (!session) return new Response('Session not found', { status: 404 })
|
if (!session) return new Response('Session not found', { status: 404 })
|
||||||
|
|
||||||
const systemPrompt = await servTrpc.chat.getSystemPrompt() || 'You are an AI recruiter assistant.'
|
const systemPrompt = await servTrpc.chat.getSystemPrompt() || 'You are an AI recruiter assistant.'
|
||||||
|
const model = await servTrpc.chat.getModel()
|
||||||
|
|
||||||
// Save the latest user message
|
// Save the latest user message
|
||||||
const lastMessage = messages[messages.length - 1]
|
const lastMessage = messages[messages.length - 1]
|
||||||
@@ -46,7 +47,7 @@ export async function POST(req: Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const result = streamText({
|
const result = streamText({
|
||||||
model: openai('gpt-5-mini'),
|
model: openai(model),
|
||||||
system: systemPrompt,
|
system: systemPrompt,
|
||||||
messages: await convertToModelMessages(messages),
|
messages: await convertToModelMessages(messages),
|
||||||
tools: {
|
tools: {
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ export const chatMessageRelations = relations(chatMessage, ({ one }) => ({
|
|||||||
export const systemSettings = createTable(
|
export const systemSettings = createTable(
|
||||||
"systemSetting",
|
"systemSetting",
|
||||||
(d) => ({
|
(d) => ({
|
||||||
systemPropmt: d.text()
|
systemPropmt: d.text(),
|
||||||
|
model: d.text()
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,26 @@ import { isAdmin } from '~/app/actions';
|
|||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import { eq } from 'drizzle-orm';
|
import { eq } from 'drizzle-orm';
|
||||||
import { clerkClient, auth } from '@clerk/nextjs/server'
|
import { clerkClient, auth } from '@clerk/nextjs/server'
|
||||||
|
import { env } from '~/env'
|
||||||
|
|
||||||
|
export const DEFAULT_MODEL = 'gpt-5-mini'
|
||||||
|
|
||||||
|
// Models returned by the OpenAI API that aren't usable for chat completions.
|
||||||
|
const NON_CHAT_MODEL = /embedding|image|audio|realtime|transcribe|tts|whisper|moderation|dall-e|search|codex|instruct/
|
||||||
|
|
||||||
|
async function readSettings() {
|
||||||
|
return db.select().from(systemSettings).limit(1).then((r) => r[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
async function writeSettings(values: { systemPropmt?: string | null; model?: string | null }) {
|
||||||
|
const current = await readSettings()
|
||||||
|
await db.delete(systemSettings)
|
||||||
|
await db.insert(systemSettings).values({
|
||||||
|
systemPropmt: values.systemPropmt ?? current?.systemPropmt ?? null,
|
||||||
|
model: values.model ?? current?.model ?? null,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
export const chatRouter = router({
|
export const chatRouter = router({
|
||||||
getSession: publicProcedure.query(async () => {
|
getSession: publicProcedure.query(async () => {
|
||||||
const { userId } = await auth();
|
const { userId } = await auth();
|
||||||
@@ -66,13 +86,34 @@ export const chatRouter = router({
|
|||||||
|
|
||||||
}),
|
}),
|
||||||
getSystemPrompt: publicProcedure.query(async () => {
|
getSystemPrompt: publicProcedure.query(async () => {
|
||||||
const row = await db.select().from(systemSettings).limit(1).then((r) => r[0])
|
const row = await readSettings()
|
||||||
return row?.systemPropmt ?? ''
|
return row?.systemPropmt ?? ''
|
||||||
}),
|
}),
|
||||||
updateSystemPrompt: publicProcedure.input(z.object({ prompt: z.string() })).mutation(async ({ input }) => {
|
updateSystemPrompt: publicProcedure.input(z.object({ prompt: z.string() })).mutation(async ({ input }) => {
|
||||||
if (!(await isAdmin())) throw new TRPCError({ code: 'FORBIDDEN' })
|
if (!(await isAdmin())) throw new TRPCError({ code: 'FORBIDDEN' })
|
||||||
await db.delete(systemSettings)
|
await writeSettings({ systemPropmt: input.prompt })
|
||||||
await db.insert(systemSettings).values({ systemPropmt: input.prompt })
|
}),
|
||||||
|
getModel: publicProcedure.query(async () => {
|
||||||
|
const row = await readSettings()
|
||||||
|
return row?.model ?? DEFAULT_MODEL
|
||||||
|
}),
|
||||||
|
listModels: publicProcedure.query(async () => {
|
||||||
|
if (!(await isAdmin())) throw new TRPCError({ code: 'FORBIDDEN' })
|
||||||
|
const res = await fetch('https://api.openai.com/v1/models', {
|
||||||
|
headers: { Authorization: `Bearer ${env.OPENAI_API_KEY}` },
|
||||||
|
})
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `failed to fetch models (${res.status})` })
|
||||||
|
}
|
||||||
|
const json = (await res.json()) as { data: { id: string }[] }
|
||||||
|
return json.data
|
||||||
|
.map((m) => m.id)
|
||||||
|
.filter((id) => (id.startsWith('gpt') || /^o\d/.test(id) || id.startsWith('chatgpt')) && !NON_CHAT_MODEL.test(id))
|
||||||
|
.sort()
|
||||||
|
}),
|
||||||
|
updateModel: publicProcedure.input(z.object({ model: z.string() })).mutation(async ({ input }) => {
|
||||||
|
if (!(await isAdmin())) throw new TRPCError({ code: 'FORBIDDEN' })
|
||||||
|
await writeSettings({ model: input.model })
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user