Thomas G. Lopes commited on
Commit
3b86586
·
1 Parent(s): 078734b

fix structured output detection

Browse files
src/lib/components/inference-playground/code-snippets.svelte CHANGED
@@ -1,6 +1,5 @@
1
  <script lang="ts">
2
  import { type ConversationClass } from "$lib/state/conversations.svelte";
3
- import { structuredForbiddenProviders } from "$lib/state/models.svelte";
4
  import { token } from "$lib/state/token.svelte.js";
5
  import { billing } from "$lib/state/billing.svelte";
6
  import { isCustomModel } from "$lib/types.js";
@@ -59,7 +58,7 @@
59
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
60
  } as any;
61
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
62
- if (data.structuredOutput && !structuredForbiddenProviders.includes(conversation.data.provider as any)) {
63
  opts.structured_output = data.structuredOutput;
64
  }
65
 
 
1
  <script lang="ts">
2
  import { type ConversationClass } from "$lib/state/conversations.svelte";
 
3
  import { token } from "$lib/state/token.svelte.js";
4
  import { billing } from "$lib/state/billing.svelte";
5
  import { isCustomModel } from "$lib/types.js";
 
58
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
59
  } as any;
60
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
61
+ if (data.structuredOutput && conversation.isStructuredOutputAllowed) {
62
  opts.structured_output = data.structuredOutput;
63
  }
64
 
src/lib/components/inference-playground/generation-config.svelte CHANGED
@@ -1,14 +1,13 @@
1
  <script lang="ts">
2
  import type { ConversationClass } from "$lib/state/conversations.svelte.js";
3
- import { structuredForbiddenProviders } from "$lib/state/models.svelte.js";
4
  import { maxAllowedTokens } from "$lib/utils/business.svelte.js";
 
5
  import { isNumber } from "$lib/utils/is.js";
6
  import { watch } from "runed";
7
  import IconX from "~icons/carbon/close";
 
8
  import { GENERATION_CONFIG_KEYS, GENERATION_CONFIG_SETTINGS } from "./generation-config-settings.js";
9
  import StructuredOutputModal, { openStructuredOutputModal } from "./structured-output-modal.svelte";
10
- import ExtraParamsModal, { openExtraParamsModal } from "./extra-params-modal.svelte";
11
- import { cn } from "$lib/utils/cn.js";
12
 
13
  interface Props {
14
  conversation: ConversationClass;
@@ -103,7 +102,7 @@
103
  </label>
104
 
105
  <!-- eslint-disable-next-line @typescript-eslint/no-explicit-any -->
106
- {#if !structuredForbiddenProviders.includes(conversation.data.provider as any)}
107
  <label class="mt-2 flex cursor-pointer items-center justify-between" for="structured-output">
108
  <span class="text-sm font-medium text-gray-900 dark:text-gray-300">Structured Output</span>
109
  <div class="flex items-center gap-2">
 
1
  <script lang="ts">
2
  import type { ConversationClass } from "$lib/state/conversations.svelte.js";
 
3
  import { maxAllowedTokens } from "$lib/utils/business.svelte.js";
4
+ import { cn } from "$lib/utils/cn.js";
5
  import { isNumber } from "$lib/utils/is.js";
6
  import { watch } from "runed";
7
  import IconX from "~icons/carbon/close";
8
+ import ExtraParamsModal, { openExtraParamsModal } from "./extra-params-modal.svelte";
9
  import { GENERATION_CONFIG_KEYS, GENERATION_CONFIG_SETTINGS } from "./generation-config-settings.js";
10
  import StructuredOutputModal, { openStructuredOutputModal } from "./structured-output-modal.svelte";
 
 
11
 
12
  interface Props {
13
  conversation: ConversationClass;
 
102
  </label>
103
 
104
  <!-- eslint-disable-next-line @typescript-eslint/no-explicit-any -->
105
+ {#if conversation.isStructuredOutputAllowed}
106
  <label class="mt-2 flex cursor-pointer items-center justify-between" for="structured-output">
107
  <span class="text-sm font-medium text-gray-900 dark:text-gray-300">Structured Output</span>
108
  <div class="flex items-center gap-2">
src/lib/state/conversations.svelte.ts CHANGED
@@ -4,10 +4,10 @@ import {
4
  } from "$lib/components/inference-playground/generation-config-settings.js";
5
  import { addToast } from "$lib/components/toaster.svelte.js";
6
  import { AbortManager } from "$lib/spells/abort-manager.svelte";
7
- import { PipelineTag, Provider, type ConversationMessage, type GenerationStatistics, type Model } from "$lib/types.js";
8
  import { handleNonStreamingResponse, handleStreamingResponse, estimateTokens } from "$lib/utils/business.svelte.js";
9
  import { omit, snapshot } from "$lib/utils/object.svelte";
10
- import { models, structuredForbiddenProviders } from "./models.svelte";
11
  import { pricing } from "./pricing.svelte.js";
12
  import { DEFAULT_PROJECT_ID, ProjectEntity, projects } from "./projects.svelte";
13
  import { token } from "./token.svelte";
@@ -107,9 +107,7 @@ export class ConversationClass {
107
  }
108
 
109
  get isStructuredOutputAllowed() {
110
- const forbiddenProvider =
111
- this.data.provider && structuredForbiddenProviders.includes(this.data.provider as Provider);
112
- return !forbiddenProvider;
113
  }
114
 
115
  get isStructuredOutputEnabled() {
 
4
  } from "$lib/components/inference-playground/generation-config-settings.js";
5
  import { addToast } from "$lib/components/toaster.svelte.js";
6
  import { AbortManager } from "$lib/spells/abort-manager.svelte";
7
+ import { PipelineTag, type ConversationMessage, type GenerationStatistics, type Model } from "$lib/types.js";
8
  import { handleNonStreamingResponse, handleStreamingResponse, estimateTokens } from "$lib/utils/business.svelte.js";
9
  import { omit, snapshot } from "$lib/utils/object.svelte";
10
+ import { models } from "./models.svelte";
11
  import { pricing } from "./pricing.svelte.js";
12
  import { DEFAULT_PROJECT_ID, ProjectEntity, projects } from "./projects.svelte";
13
  import { token } from "./token.svelte";
 
107
  }
108
 
109
  get isStructuredOutputAllowed() {
110
+ return models.supportsStructuredOutput(this.model, this.data.provider);
 
 
111
  }
112
 
113
  get isStructuredOutputEnabled() {
src/lib/state/models.svelte.ts CHANGED
@@ -1,5 +1,5 @@
1
  import { page } from "$app/state";
2
- import { Provider, type CustomModel } from "$lib/types.js";
3
  import { edit, randomPick } from "$lib/utils/array.js";
4
  import { safeParse } from "$lib/utils/json.js";
5
  import typia from "typia";
@@ -10,13 +10,6 @@ const LOCAL_STORAGE_KEY = "hf_inference_playground_custom_models";
10
 
11
  const pageData = $derived(page.data as PageData);
12
 
13
- export const structuredForbiddenProviders: Provider[] = [
14
- Provider.Hyperbolic,
15
- Provider.Nebius,
16
- Provider.Novita,
17
- Provider.Sambanova,
18
- ];
19
-
20
  class Models {
21
  remote = $derived(pageData.models);
22
  trending = $derived(this.remote.toSorted((a, b) => b.trendingScore - a.trendingScore).slice(0, 5));
@@ -74,6 +67,13 @@ class Models {
74
  c.update({ modelId: randomPick(models.trending)?.id });
75
  });
76
  }
 
 
 
 
 
 
 
77
  }
78
 
79
  export const models = new Models();
 
1
  import { page } from "$app/state";
2
+ import { Provider, type CustomModel, type Model } from "$lib/types.js";
3
  import { edit, randomPick } from "$lib/utils/array.js";
4
  import { safeParse } from "$lib/utils/json.js";
5
  import typia from "typia";
 
10
 
11
  const pageData = $derived(page.data as PageData);
12
 
 
 
 
 
 
 
 
13
  class Models {
14
  remote = $derived(pageData.models);
15
  trending = $derived(this.remote.toSorted((a, b) => b.trendingScore - a.trendingScore).slice(0, 5));
 
67
  c.update({ modelId: randomPick(models.trending)?.id });
68
  });
69
  }
70
+
71
+ supportsStructuredOutput(model: Model | CustomModel, provider?: string) {
72
+ if (typia.is<CustomModel>(model)) return true;
73
+ const routerDataEntry = pageData.routerData.data.find(d => d.id === model.id);
74
+ if (!routerDataEntry) return false;
75
+ return routerDataEntry.providers.find(p => p.provider === provider)?.supports_structured_output ?? false;
76
+ }
77
  }
78
 
79
  export const models = new Models();
src/lib/utils/business.svelte.ts CHANGED
@@ -28,8 +28,8 @@ import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
28
  import OpenAI from "openai";
29
  import { images } from "$lib/state/images.svelte.js";
30
  import { projects } from "$lib/state/projects.svelte.js";
31
- import { structuredForbiddenProviders } from "$lib/state/models.svelte.js";
32
  import { modifySnippet } from "$lib/utils/snippets.js";
 
33
 
34
  type ChatCompletionInputMessageChunk =
35
  NonNullable<ChatCompletionInputMessage["content"]> extends string | (infer U)[] ? U : never;
@@ -89,7 +89,7 @@ function getResponseFormatObj(conversation: ConversationClass | Conversation) {
89
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
90
  const json = safeParse(data.structuredOutput?.schema ?? "");
91
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
92
- if (json && data.structuredOutput?.enabled && !structuredForbiddenProviders.includes(data.provider as any)) {
93
  switch (data.provider) {
94
  case "cohere": {
95
  return {
@@ -366,7 +366,7 @@ export function getInferenceSnippet(
366
  if (
367
  opts?.structured_output?.schema &&
368
  opts.structured_output.enabled &&
369
- !structuredForbiddenProviders.includes(provider as Provider)
370
  ) {
371
  return {
372
  ...s,
 
28
  import OpenAI from "openai";
29
  import { images } from "$lib/state/images.svelte.js";
30
  import { projects } from "$lib/state/projects.svelte.js";
 
31
  import { modifySnippet } from "$lib/utils/snippets.js";
32
+ import { models } from "$lib/state/models.svelte";
33
 
34
  type ChatCompletionInputMessageChunk =
35
  NonNullable<ChatCompletionInputMessage["content"]> extends string | (infer U)[] ? U : never;
 
89
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
90
  const json = safeParse(data.structuredOutput?.schema ?? "");
91
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
92
+ if (json && data.structuredOutput?.enabled && models.supportsStructuredOutput(conversation.model, data.provider)) {
93
  switch (data.provider) {
94
  case "cohere": {
95
  return {
 
366
  if (
367
  opts?.structured_output?.schema &&
368
  opts.structured_output.enabled &&
369
+ models.supportsStructuredOutput(conversation.model, provider)
370
  ) {
371
  return {
372
  ...s,
src/routes/+page.ts CHANGED
@@ -1,6 +1,36 @@
 
1
  import type { PageLoad } from "./$types.js";
2
  import type { ApiModelsResponse } from "./api/models/+server.js";
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  export const load: PageLoad = async ({ fetch }) => {
5
  const [modelsRes, routerRes] = await Promise.all([
6
  fetch("/api/models"),
@@ -8,7 +38,7 @@ export const load: PageLoad = async ({ fetch }) => {
8
  ]);
9
 
10
  const models: ApiModelsResponse = await modelsRes.json();
11
- const routerData = await routerRes.json();
12
 
13
  return {
14
  ...models,
 
1
+ import type { Provider } from "$lib/types.js";
2
  import type { PageLoad } from "./$types.js";
3
  import type { ApiModelsResponse } from "./api/models/+server.js";
4
 
5
+ export type RouterData = {
6
+ object: string;
7
+ data: Datum[];
8
+ };
9
+
10
+ type Datum = {
11
+ id: string;
12
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
13
+ object: any;
14
+ created: number;
15
+ owned_by: string;
16
+ providers: ProviderElement[];
17
+ };
18
+
19
+ type ProviderElement = {
20
+ provider: Provider;
21
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
22
+ status: any;
23
+ context_length?: number;
24
+ pricing?: Pricing;
25
+ supports_tools?: boolean;
26
+ supports_structured_output?: boolean;
27
+ };
28
+
29
+ type Pricing = {
30
+ input: number;
31
+ output: number;
32
+ };
33
+
34
  export const load: PageLoad = async ({ fetch }) => {
35
  const [modelsRes, routerRes] = await Promise.all([
36
  fetch("/api/models"),
 
38
  ]);
39
 
40
  const models: ApiModelsResponse = await modelsRes.json();
41
+ const routerData = (await routerRes.json()) as RouterData;
42
 
43
  return {
44
  ...models,