Add feature for LLM fine-tuning from HuggingFace (transformer models only) — uniAIDevs/ai#5
![Logo of Sweep](/_next/image?url=%2Flogo.png&w=64&q=75)
Add feature for LLM finetuning from HuggingFace (transformer models only)
uniAIDevs/ai#5
> > >
✓ Completed in 29 minutes, 3 months ago using GPT-4 • Book a call • Report a bug
Progress
Create
gateway/internal/provider/huggingface/huggingface.go
60911bd
1package huggingface
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io/ioutil"
8 "net/http"
9 "github.com/missingstudio/ai/gateway/internal/provider/base"
10 "github.com/missingstudio/common/errors"
11)
12
13type HuggingFaceProvider struct {
14 APIKey string
15 BaseURL string
16}
17
18func (hfp *HuggingFaceProvider) Info() base.ProviderInfo {
19 return base.ProviderInfo{
20 Name: "HuggingFace",
21 Description: "Provider for interacting with HuggingFace's transformer models",
22 }
23}
24
25func (hfp *HuggingFaceProvider) Models(ctx context.Context) ([]string, error) {
26 url := fmt.Sprintf("%s/models", hfp.BaseURL)
27 req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
28 if err != nil {
29 return nil, err
30 }
31
32 req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
33 client := &http.Client{}
34 resp, err := client.Do(req)
35 if err != nil {
36 return nil, err
37 }
38 defer resp.Body.Close()
39
40 if resp.StatusCode != http.StatusOK {
41 return nil, errors.NewBadRequest("failed to fetch models from HuggingFace")
42 }
43
44 var models []string
45 err = json.NewDecoder(resp.Body).Decode(&models)
46 if err != nil {
47 return nil, err
48 }
49
50 return models, nil
51}
52
53func (hfp *HuggingFaceProvider) InitiateFineTuning(ctx context.Context, model string, parameters map[string]interface{}) (string, error) {
54 url := fmt.Sprintf("%s/fine-tune", hfp.BaseURL)
55 payload, err := json.Marshal(map[string]interface{}{
56 "model": model,
57 "parameters": parameters,
58 })
59 if err != nil {
60 return "", err
61 }
62
63 req, err := http.NewRequestWithContext(ctx, "POST", url, ioutil.NopCloser(bytes.NewReader(payload)))
64 if err != nil {
65 return "", err
66 }
67
68 req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
69 req.Header.Add("Content-Type", "application/json")
70 client := &http.Client{}
71 resp, err := client.Do(req)
72 if err != nil {
73 return "", err
74 }
75 defer resp.Body.Close()
76
77 if resp.StatusCode != http.StatusOK {
78 return "", errors.NewBadRequest("failed to initiate fine-tuning on HuggingFace")
79 }
80
81 var result map[string]string
82 err = json.NewDecoder(resp.Body).Decode(&result)
83 if err != nil {
84 return "", err
85 }
86
87 return result["job_id"], nil
88}
89
90func (hfp *HuggingFaceProvider) RetrieveFineTuningResults(ctx context.Context, jobID string) (map[string]interface{}, error) {
91 url := fmt.Sprintf("%s/fine-tune/%s", hfp.BaseURL, jobID)
92 req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
93 if err != nil {
94 return nil, err
95 }
96
97 req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
98 client := &http.Client{}
99 resp, err := client.Do(req)
100 if err != nil {
101 return nil, err
102 }
103 defer resp.Body.Close()
104
105 if resp.StatusCode != http.StatusOK {
106 return nil, errors.NewBadRequest("failed to retrieve fine-tuning results from HuggingFace")
107 }
108
109 var result map[string]interface{}
110 err = json.NewDecoder(resp.Body).Decode(&result)
111 if err != nil {
112 return nil, err
113 }
114
115 return result, nil
116}
117
- Create a new file `huggingface.go` in the `gateway/internal/provider/huggingface/` directory. This file will define the `HuggingFaceProvider` struct and implement the `Provider` interface from `provider.go`.
- Implement methods for the `Provider` interface, focusing on fine-tuning transformer models. This includes methods for listing available models, initiating fine-tuning jobs, and retrieving the results of fine-tuning.
- Import the packages needed to call the HuggingFace API, make HTTP requests, and process JSON data.
Run GitHub Actions for
gateway/internal/provider/huggingface/huggingface.go
Ran GitHub Actions for 60911bdcddb9e8d725b85d2cf0f8421262fce932:
Create
gateway/internal/api/v1/finetune.go
82f5593
1package v1
2
3import (
4 "context"
5 "encoding/json"
6 "net/http"
7
8 "connectrpc.com/connect"
9 "github.com/missingstudio/ai/gateway/core/provider"
10 "github.com/missingstudio/ai/gateway/internal/provider/huggingface"
11 "github.com/missingstudio/common/errors"
12 llmv1 "github.com/missingstudio/protos/pkg/llm/v1"
13)
14
15func (s *V1Handler) InitiateFineTuning(ctx context.Context, req *connect.Request[llmv1.FineTuneRequest]) (*connect.Response[llmv1.FineTuneResponse], error) {
16 hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
17 if err != nil {
18 return nil, errors.NewInternal("failed to get HuggingFace provider")
19 }
20
21 jobID, err := hfProvider.(*huggingface.HuggingFaceProvider).InitiateFineTuning(ctx, req.Payload.Model, req.Payload.Parameters)
22 if err != nil {
23 return nil, errors.NewInternal("failed to initiate fine-tuning: " + err.Error())
24 }
25
26 return connect.NewResponse(&llmv1.FineTuneResponse{
27 JobId: jobID,
28 }), nil
29}
30
31func (s *V1Handler) CheckFineTuningStatus(ctx context.Context, req *connect.Request[llmv1.FineTuneStatusRequest]) (*connect.Response[llmv1.FineTuneStatusResponse], error) {
32 hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
33 if err != nil {
34 return nil, errors.NewInternal("failed to get HuggingFace provider")
35 }
36
37 result, err := hfProvider.(*huggingface.HuggingFaceProvider).RetrieveFineTuningResults(ctx, req.Payload.JobId)
38 if err != nil {
39 return nil, errors.NewInternal("failed to retrieve fine-tuning results: " + err.Error())
40 }
41
42 status, ok := result["status"].(string)
43 if !ok {
44 return nil, errors.NewInternal("unexpected response format from HuggingFace")
45 }
46
47 return connect.NewResponse(&llmv1.FineTuneStatusResponse{
48 Status: status,
49 }), nil
50}
51
- Create a new file `finetune.go` in the `gateway/internal/api/v1/` directory to handle API requests related to fine-tuning.
- Define endpoints for initiating fine-tuning jobs and checking their status. Use the `HuggingFaceProvider` to interact with the HuggingFace API.
- Implement handlers for the new endpoints: parse the request data, call the appropriate methods on the `HuggingFaceProvider`, and format the responses.
Run GitHub Actions for
gateway/internal/api/v1/finetune.go
Ran GitHub Actions for 82f55933d34ab4ae641421256a30790d745a3cfd:
Modify
gateway/internal/api/v1/v1.go
1package v1
2
3import (
4 "github.com/missingstudio/ai/gateway/core/apikey"
5 "github.com/missingstudio/ai/gateway/core/prompt"
6 "github.com/missingstudio/ai/gateway/core/provider"
7 "github.com/missingstudio/ai/gateway/internal/ingester"
8 iprovider "github.com/missingstudio/ai/gateway/internal/provider"
9 "github.com/missingstudio/protos/pkg/llm/v1/llmv1connect"
10 "github.com/missingstudio/protos/pkg/prompt/v1/promptv1connect"
11)
12
13type V1Handler struct {
14 llmv1connect.UnimplementedLLMServiceHandler
15 promptv1connect.UnimplementedPromptRegistryServiceHandler
16 ingester ingester.Ingester
17 apikeyService *apikey.Service
18 promptService *prompt.Service
19 providerService *provider.Service
20 iProviderService *iprovider.Service
21}
22
23func NewHandlerV1(
24 ingester ingester.Ingester,
25 apikeyService *apikey.Service,
26 promptService *prompt.Service,
27 providerService *provider.Service,
28 iProviderService *iprovider.Service,
29) *V1Handler {
30 return &V1Handler{
31 ingester: ingester,
32 apikeyService: apikeyService,
33 promptService: promptService,
34 providerService: providerService,
35 iProviderService: iProviderService,
36 }
37}
38
- Register the new fine-tuning endpoints defined in `finetune.go` with the router in `v1.go`.
- Ensure that the endpoints are properly authenticated and validate request data before processing.
Run GitHub Actions for
gateway/internal/api/v1/v1.go
Modify (changed)
playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx
Changed playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx
in 3f04b91
14 | const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000"; | 14 | const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000"; |
15 | export function useModelFetch() { | 15 | export function useModelFetch() { |
16 | const [providers, setProviders] = useState<ModelType[]>([]); | 16 | const [providers, setProviders] = useState<ModelType[]>([]); |
17 | const [isFineTuning, setIsFineTuning] = useState<boolean>(false); | ||
17 | 18 | ||
18 | useEffect(() => { | 19 | useEffect(() => { |
20 | const fetchEndpoint = isFineTuning ? `${BASE_URL}/api/v1/finetune/models` : `${BASE_URL}/api/v1/models`; | ||
19 | async function fetchModels() { | 21 | async function fetchModels() { |
20 | try { | 22 | try { |
21 | const response = await fetch(`${BASE_URL}/api/v1/models`); | 23 | const response = await fetch(fetchEndpoint); |
22 | const { models } = await response.json(); | 24 | const { models } = await response.json(); |
23 | const fetchedProviders: ModelType[] = Object.keys(models).map( | 25 | const fetchedProviders: ModelType[] = Object.keys(models).map( |
24 | (key) => ({ | 26 | (key) => ({ |
- Update the `useModelFetch` hook to include an option for selecting transformer models for fine-tuning. This may involve adding new state to track whether the user is interested in fine-tuning, and fetching additional data from the backend if necessary.
- Adjust the API call to include requests to the new fine-tuning endpoints as needed.
Run GitHub Actions for
playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx
Ran GitHub Actions for 3f04b91788cb909b639228a1376a27f2ebcfa2cf:
Modify (changed)
playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx
Changed playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx
in bd0d237
26 | 26 | ||
27 | export default function ModelSelector(props: ModelSelectorProps) { | 27 | export default function ModelSelector(props: ModelSelectorProps) { |
28 | const [open, setOpen] = React.useState(false); | 28 | const [open, setOpen] = React.useState(false); |
29 | const { providers } = useModelFetch(); | 29 | const [isFineTuning, setIsFineTuning] = React.useState(false); |
30 | const { providers } = useModelFetch(isFineTuning); | ||
30 | const { model, setModel, setProvider } = useStore(); | 31 | const { model, setModel, setProvider } = useStore(); |
31 | 32 | ||
33 | const toggleFineTuning = () => setIsFineTuning(!isFineTuning); | ||
34 | |||
32 | return ( | 35 | return ( |
33 | <div className="flex items-center gap-2"> | 36 | <div className="flex items-center gap-2"> |
34 | <Label htmlFor="model">Model: </Label> | 37 | <Label htmlFor="model">Model: </Label> |
38 | <Button variant="outline" onClick={toggleFineTuning}> | ||
39 | {isFineTuning ? 'Select for Fine-Tuning' : 'Select Model'} | ||
40 | </Button> | ||
35 | <Popover open={open} onOpenChange={setOpen} {...props}> | 41 | <Popover open={open} onOpenChange={setOpen} {...props}> |
36 | <PopoverTrigger asChild> | 42 | <PopoverTrigger asChild> |
37 | <Button | 43 | <Button |
- Modify the `ModelSelector` component to allow users to select transformer models for fine-tuning. This includes UI changes to present fine-tuning options and possibly a new UI component for selecting or uploading datasets for fine-tuning.
- Ensure that the component interacts correctly with the updated `useModelFetch` hook to fetch and display the relevant options.
Run GitHub Actions for
playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx
Ran GitHub Actions for bd0d2375e35ba4f7425000471c68d136c3864e24:
Plan
This is based on the results of the Planning step. The plan may expand from failed GitHub Actions runs.
Create
gateway/internal/provider/huggingface/huggingface.go
60911bd
1package huggingface
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io/ioutil"
8 "net/http"
9 "github.com/missingstudio/ai/gateway/internal/provider/base"
10 "github.com/missingstudio/common/errors"
11)
12
13type HuggingFaceProvider struct {
14 APIKey string
15 BaseURL string
16}
17
18func (hfp *HuggingFaceProvider) Info() base.ProviderInfo {
19 return base.ProviderInfo{
20 Name: "HuggingFace",
21 Description: "Provider for interacting with HuggingFace's transformer models",
22 }
23}
24
25func (hfp *HuggingFaceProvider) Models(ctx context.Context) ([]string, error) {
26 url := fmt.Sprintf("%s/models", hfp.BaseURL)
27 req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
28 if err != nil {
29 return nil, err
30 }
31
32 req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
33 client := &http.Client{}
34 resp, err := client.Do(req)
35 if err != nil {
36 return nil, err
37 }
38 defer resp.Body.Close()
39
40 if resp.StatusCode != http.StatusOK {
41 return nil, errors.NewBadRequest("failed to fetch models from HuggingFace")
42 }
43
44 var models []string
45 err = json.NewDecoder(resp.Body).Decode(&models)
46 if err != nil {
47 return nil, err
48 }
49
50 return models, nil
51}
52
53func (hfp *HuggingFaceProvider) InitiateFineTuning(ctx context.Context, model string, parameters map[string]interface{}) (string, error) {
54 url := fmt.Sprintf("%s/fine-tune", hfp.BaseURL)
55 payload, err := json.Marshal(map[string]interface{}{
56 "model": model,
57 "parameters": parameters,
58 })
59 if err != nil {
60 return "", err
61 }
62
63 req, err := http.NewRequestWithContext(ctx, "POST", url, ioutil.NopCloser(bytes.NewReader(payload)))
64 if err != nil {
65 return "", err
66 }
67
68 req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
69 req.Header.Add("Content-Type", "application/json")
70 client := &http.Client{}
71 resp, err := client.Do(req)
72 if err != nil {
73 return "", err
74 }
75 defer resp.Body.Close()
76
77 if resp.StatusCode != http.StatusOK {
78 return "", errors.NewBadRequest("failed to initiate fine-tuning on HuggingFace")
79 }
80
81 var result map[string]string
82 err = json.NewDecoder(resp.Body).Decode(&result)
83 if err != nil {
84 return "", err
85 }
86
87 return result["job_id"], nil
88}
89
90func (hfp *HuggingFaceProvider) RetrieveFineTuningResults(ctx context.Context, jobID string) (map[string]interface{}, error) {
91 url := fmt.Sprintf("%s/fine-tune/%s", hfp.BaseURL, jobID)
92 req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
93 if err != nil {
94 return nil, err
95 }
96
97 req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
98 client := &http.Client{}
99 resp, err := client.Do(req)
100 if err != nil {
101 return nil, err
102 }
103 defer resp.Body.Close()
104
105 if resp.StatusCode != http.StatusOK {
106 return nil, errors.NewBadRequest("failed to retrieve fine-tuning results from HuggingFace")
107 }
108
109 var result map[string]interface{}
110 err = json.NewDecoder(resp.Body).Decode(&result)
111 if err != nil {
112 return nil, err
113 }
114
115 return result, nil
116}
117
Run GitHub Actions for
gateway/internal/provider/huggingface/huggingface.go
Create
gateway/internal/api/v1/finetune.go
82f5593
1package v1
2
3import (
4 "context"
5 "encoding/json"
6 "net/http"
7
8 "connectrpc.com/connect"
9 "github.com/missingstudio/ai/gateway/core/provider"
10 "github.com/missingstudio/ai/gateway/internal/provider/huggingface"
11 "github.com/missingstudio/common/errors"
12 llmv1 "github.com/missingstudio/protos/pkg/llm/v1"
13)
14
15func (s *V1Handler) InitiateFineTuning(ctx context.Context, req *connect.Request[llmv1.FineTuneRequest]) (*connect.Response[llmv1.FineTuneResponse], error) {
16 hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
17 if err != nil {
18 return nil, errors.NewInternal("failed to get HuggingFace provider")
19 }
20
21 jobID, err := hfProvider.(*huggingface.HuggingFaceProvider).InitiateFineTuning(ctx, req.Payload.Model, req.Payload.Parameters)
22 if err != nil {
23 return nil, errors.NewInternal("failed to initiate fine-tuning: " + err.Error())
24 }
25
26 return connect.NewResponse(&llmv1.FineTuneResponse{
27 JobId: jobID,
28 }), nil
29}
30
31func (s *V1Handler) CheckFineTuningStatus(ctx context.Context, req *connect.Request[llmv1.FineTuneStatusRequest]) (*connect.Response[llmv1.FineTuneStatusResponse], error) {
32 hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
33 if err != nil {
34 return nil, errors.NewInternal("failed to get HuggingFace provider")
35 }
36
37 result, err := hfProvider.(*huggingface.HuggingFaceProvider).RetrieveFineTuningResults(ctx, req.Payload.JobId)
38 if err != nil {
39 return nil, errors.NewInternal("failed to retrieve fine-tuning results: " + err.Error())
40 }
41
42 status, ok := result["status"].(string)
43 if !ok {
44 return nil, errors.NewInternal("unexpected response format from HuggingFace")
45 }
46
47 return connect.NewResponse(&llmv1.FineTuneStatusResponse{
48 Status: status,
49 }), nil
50}
51
Run GitHub Actions for
gateway/internal/api/v1/finetune.go
Run GitHub Actions for
gateway/internal/api/v1/v1.go
Run GitHub Actions for
playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx
Run GitHub Actions for
playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx
Code Snippets Found
This is based on the results of the Searching step.
gateway/internal/provider/openai/openai.go:0-123
1package openai
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "encoding/json"
8 "fmt"
9 "net/http"
10 "strings"
11
12 "github.com/missingstudio/ai/gateway/core/chat"
13 "github.com/missingstudio/ai/gateway/core/provider"
14 "github.com/missingstudio/ai/gateway/internal/requester"
15)
16
// OpenAIModels is the fixed list of model IDs advertised by the OpenAI
// provider, in the order they are reported to clients.
var OpenAIModels = []string{
	// GPT-4 family.
	"gpt-4-0125-preview",
	"gpt-4-turbo-preview",
	"gpt-4-1106-preview",
	"gpt-4-vision-preview",
	"gpt-4-1106-vision-preview",
	"gpt-4",
	"gpt-4-0613",
	"gpt-4-32k",
	"gpt-4-32k-0613",

	// GPT-3.5 family.
	"gpt-3.5-turbo-0125",
	"gpt-3.5-turbo",
	"gpt-3.5-turbo-1106",
	"gpt-3.5-turbo-instruct",
}
32
33func (oai *openAIProvider) ChatCompletions(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) {
34 client := requester.NewHTTPClient()
35
36 payload.Stream = false
37 rawPayload, err := json.Marshal(payload)
38 if err != nil {
39 return nil, fmt.Errorf("unable to marshal openai chat request payload: %w", err)
40 }
41
42 requestURL := fmt.Sprintf("%s%s", oai.config.BaseURL, oai.config.ChatCompletions)
43 req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
44 if err != nil {
45 return nil, err
46 }
47
48 req = oai.AddDefaultHeaders(req, provider.AuthorizationHeader)
49 resp, err := client.SendRequestRaw(req)
50 if err != nil {
51 return nil, err
52 }
53
54 data := &chat.ChatCompletionResponse{}
55 if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
56 return nil, err
57 }
58
59 return data, nil
60}
61
62func (*openAIProvider) Models() []string {
63 return OpenAIModels
64}
65
66func (oai *openAIProvider) AddDefaultHeaders(req *http.Request, key string) *http.Request {
67 providerConfigMap := oai.provider.GetConfig([]string{
68 provider.AuthorizationHeader,
69 })
70
71 var authorizationHeader string
72 if val, ok := providerConfigMap[provider.AuthorizationHeader].(string); ok && val != "" {
73 authorizationHeader = val
74 }
75
76 req.Header.Add("Content-Type", "application/json")
77 req.Header.Add("Authorization", authorizationHeader)
78 return req
79}
80
81func (oai *openAIProvider) StreamChatCompletions(ctx context.Context, payload *chat.ChatCompletionRequest, stream chan []byte) error {
82 client := requester.NewHTTPClient()
83
84 payload.Stream = true
85 rawPayload, err := json.Marshal(payload)
86 if err != nil {
87 return fmt.Errorf("unable to marshal openai chat request payload: %w", err)
88 }
89
90 requestURL := fmt.Sprintf("%s%s", oai.config.BaseURL, oai.config.ChatCompletions)
91 req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
92 if err != nil {
93 return err
94 }
95
96 req = oai.AddDefaultHeaders(req, provider.AuthorizationHeader)
97 resp, err := client.SendRequestRaw(req)
98 if err != nil {
99 return err
100 }
101 defer resp.Body.Close()
102
103 scanner := bufio.NewScanner(resp.Body)
104 for scanner.Scan() {
105 stream <- scanner.Bytes()
106
107 line := scanner.Text()
108 if strings.HasPrefix(line, "data:") {
109 event := strings.TrimPrefix(line, "data:")
110 event = strings.TrimSpace(strings.Trim(event, "\n"))
111 if strings.Contains(line, "[DONE]") {
112 break
113 }
114
115 var data chat.ChatCompletionResponse
116 if err := json.Unmarshal([]byte(event), &data); err != nil {
117 return err
118 }
119 }
120 }
121
122 close(stream)
123 return nil
gateway/internal/api/v1/models.go:0-46
1package v1
2
3import (
4 "context"
5
6 "connectrpc.com/connect"
7 "github.com/missingstudio/ai/gateway/core/provider"
8 "github.com/missingstudio/ai/gateway/internal/provider/base"
9 llmv1 "github.com/missingstudio/protos/pkg/llm/v1"
10)
11
12func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.ModelRequest]) (*connect.Response[llmv1.ModelResponse], error) {
13 allProviderModels := map[string]*llmv1.ProviderModels{}
14
15 for name := range base.ProviderRegistry {
16 // Check if the provider is healthy before fetching models
17 if !router.DefaultHealthChecker{}.IsHealthy(name) {
18 continue
19 }
20
21 provider, err := s.iProviderService.GetProvider(provider.Provider{Name: name})
22 if err != nil {
23 continue
24 }
25
26 providerInfo := provider.Info()
27 providerModels := provider.Models()
28 providerName := providerInfo.Name
29
30 var models []*llmv1.Model
31 for _, val := range providerModels {
32 models = append(models, &llmv1.Model{
33 Name: val,
34 Value: val,
35 })
36 }
37
38 allProviderModels[name] = &llmv1.ProviderModels{
39 Name: providerName,
40 Models: models,
41 }
42 }
43
44 return connect.NewResponse(&llmv1.ModelResponse{
45 Models: allProviderModels,
46 }), nil
playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx:0-38
1import { useEffect, useState } from "react";
2import { toast } from "sonner";
3
4interface ModelType {
5 name: string;
6 models: Model[];
7}
8
9interface Model {
10 name: string;
11 value: string;
12}
13
14const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000";
15export function useModelFetch() {
16 const [providers, setProviders] = useState<ModelType[]>([]);
17
18 useEffect(() => {
19 async function fetchModels() {
20 try {
21 const response = await fetch(`${BASE_URL}/api/v1/models`);
22 const { models } = await response.json();
23 const fetchedProviders: ModelType[] = Object.keys(models).map(
24 (key) => ({
25 name: models[key].name,
26 models: models[key].models,
27 })
28 );
29 setProviders(fetchedProviders);
30 } catch (e) {
31 toast.error("AI Studio is not running");
32 }
33 }
34
35 fetchModels();
36 }, []);
37
38 return { providers };
playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx:0-73
1"use client";
2import { CaretSortIcon } from "@radix-ui/react-icons";
3import { PopoverProps } from "@radix-ui/react-popover";
4import React from "react";
5
6import {
7 Popover,
8 PopoverContent,
9 PopoverTrigger,
10} from "@missingstudio/ui/popover";
11
12import { Button } from "@missingstudio/ui/button";
13import {
14 Command,
15 CommandEmpty,
16 CommandGroup,
17 CommandInput,
18 CommandList,
19} from "@missingstudio/ui/command";
20import { Label } from "@missingstudio/ui/label";
21import { ModelItem } from "~/app/(llm)/playground/components/modelitem";
22import { useModelFetch } from "~/app/(llm)/playground/hooks/useModelFetch";
23import { useStore } from "~/app/(llm)/playground/store";
24
25interface ModelSelectorProps extends PopoverProps {}
26
27export default function ModelSelector(props: ModelSelectorProps) {
28 const [open, setOpen] = React.useState(false);
29 const { providers } = useModelFetch();
30 const { model, setModel, setProvider } = useStore();
31
32 return (
33 <div className="flex items-center gap-2">
34 <Label htmlFor="model">Model: </Label>
35 <Popover open={open} onOpenChange={setOpen} {...props}>
36 <PopoverTrigger asChild>
37 <Button
38 variant="outline"
39 role="combobox"
40 aria-expanded={open}
41 aria-label="Select a model"
42 className="w-full justify-between"
43 >
44 {model ? model : "Select a model..."}
45 <CaretSortIcon className="ml-2 h-4 w-4 shrink-0 opacity-50" />
46 </Button>
47 </PopoverTrigger>
48 <PopoverContent align="end" className="w-[250px] p-0">
49 <Command loop>
50 <CommandList className="h-[var(--cmdk-list-height)] max-h-[400px]">
51 <CommandInput placeholder="Search Models..." />
52 <CommandEmpty>No Models found.</CommandEmpty>
53 {providers.map((provider) => (
54 <CommandGroup key={provider.name} heading={provider.name}>
55 {provider.models.map((singleModel, index) => (
56 <ModelItem
57 key={`${index}_${provider.name}_${singleModel.value}`}
58 id={`${singleModel.value}`}
59 isSelected={model === singleModel.value}
60 onSelect={() => {
61 setProvider(provider.name);
62 setModel(singleModel.value);
63 setOpen(false);
64 }}
65 />
66 ))}
67 </CommandGroup>
68 ))}
69 </CommandList>
70 </Command>
71 </PopoverContent>
72 </Popover>
73 </div>