Logo of Sweep
Add feature for LLM finetuning from HuggingFace (transformer models only)uniAIDevs/ai#5

> > >

✓ Completed in 29 minutes, 3 months ago using GPT-4  •   Book a call  •   Report a bug


Progress

  Create gateway/internal/provider/huggingface/huggingface.go (commit 60911bd)
1package huggingface
2
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil" // TODO: no longer referenced once ioutil.NopCloser is dropped; safe to remove
	"net/http"

	"github.com/missingstudio/ai/gateway/internal/provider/base"
	"github.com/missingstudio/common/errors"
)
12
// HuggingFaceProvider talks to the HuggingFace API on behalf of the
// gateway.
type HuggingFaceProvider struct {
	// APIKey is the HuggingFace access token, sent as a Bearer credential.
	APIKey string
	// BaseURL is the root URL of the HuggingFace API.
	BaseURL string
}
17
18func (hfp *HuggingFaceProvider) Info() base.ProviderInfo {
19	return base.ProviderInfo{
20		Name: "HuggingFace",
21		Description: "Provider for interacting with HuggingFace's transformer models",
22	}
23}
24
25func (hfp *HuggingFaceProvider) Models(ctx context.Context) ([]string, error) {
26	url := fmt.Sprintf("%s/models", hfp.BaseURL)
27	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
28	if err != nil {
29		return nil, err
30	}
31
32	req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
33	client := &http.Client{}
34	resp, err := client.Do(req)
35	if err != nil {
36		return nil, err
37	}
38	defer resp.Body.Close()
39
40	if resp.StatusCode != http.StatusOK {
41		return nil, errors.NewBadRequest("failed to fetch models from HuggingFace")
42	}
43
44	var models []string
45	err = json.NewDecoder(resp.Body).Decode(&models)
46	if err != nil {
47		return nil, err
48	}
49
50	return models, nil
51}
52
53func (hfp *HuggingFaceProvider) InitiateFineTuning(ctx context.Context, model string, parameters map[string]interface{}) (string, error) {
54	url := fmt.Sprintf("%s/fine-tune", hfp.BaseURL)
55	payload, err := json.Marshal(map[string]interface{}{
56		"model": model,
57		"parameters": parameters,
58	})
59	if err != nil {
60		return "", err
61	}
62
63	req, err := http.NewRequestWithContext(ctx, "POST", url, ioutil.NopCloser(bytes.NewReader(payload)))
64	if err != nil {
65		return "", err
66	}
67
68	req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
69	req.Header.Add("Content-Type", "application/json")
70	client := &http.Client{}
71	resp, err := client.Do(req)
72	if err != nil {
73		return "", err
74	}
75	defer resp.Body.Close()
76
77	if resp.StatusCode != http.StatusOK {
78		return "", errors.NewBadRequest("failed to initiate fine-tuning on HuggingFace")
79	}
80
81	var result map[string]string
82	err = json.NewDecoder(resp.Body).Decode(&result)
83	if err != nil {
84		return "", err
85	}
86
87	return result["job_id"], nil
88}
89
90func (hfp *HuggingFaceProvider) RetrieveFineTuningResults(ctx context.Context, jobID string) (map[string]interface{}, error) {
91	url := fmt.Sprintf("%s/fine-tune/%s", hfp.BaseURL, jobID)
92	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
93	if err != nil {
94		return nil, err
95	}
96
97	req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
98	client := &http.Client{}
99	resp, err := client.Do(req)
100	if err != nil {
101		return nil, err
102	}
103	defer resp.Body.Close()
104
105	if resp.StatusCode != http.StatusOK {
106		return nil, errors.NewBadRequest("failed to retrieve fine-tuning results from HuggingFace")
107	}
108
109	var result map[string]interface{}
110	err = json.NewDecoder(resp.Body).Decode(&result)
111	if err != nil {
112		return nil, err
113	}
114
115	return result, nil
116}
117
  • Create a new file huggingface.go in the gateway/internal/provider/huggingface/ directory. This file will define the HuggingFaceProvider struct and implement the Provider interface from provider.go.
  • Implement methods for the Provider interface, focusing on fine-tuning transformer models. This includes methods for listing available models, initiating fine-tuning jobs, and retrieving the results of fine-tuning.
  • Import necessary packages for interacting with the HuggingFace API, handling HTTP requests, and processing JSON data.
  Create gateway/internal/api/v1/finetune.go (commit 82f5593)
1package v1
2
3import (
4	"context"
5	"encoding/json"
6	"net/http"
7
8	"connectrpc.com/connect"
9	"github.com/missingstudio/ai/gateway/core/provider"
10	"github.com/missingstudio/ai/gateway/internal/provider/huggingface"
11	"github.com/missingstudio/common/errors"
12	llmv1 "github.com/missingstudio/protos/pkg/llm/v1"
13)
14
15func (s *V1Handler) InitiateFineTuning(ctx context.Context, req *connect.Request[llmv1.FineTuneRequest]) (*connect.Response[llmv1.FineTuneResponse], error) {
16	hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
17	if err != nil {
18		return nil, errors.NewInternal("failed to get HuggingFace provider")
19	}
20
21	jobID, err := hfProvider.(*huggingface.HuggingFaceProvider).InitiateFineTuning(ctx, req.Payload.Model, req.Payload.Parameters)
22	if err != nil {
23		return nil, errors.NewInternal("failed to initiate fine-tuning: " + err.Error())
24	}
25
26	return connect.NewResponse(&llmv1.FineTuneResponse{
27		JobId: jobID,
28	}), nil
29}
30
31func (s *V1Handler) CheckFineTuningStatus(ctx context.Context, req *connect.Request[llmv1.FineTuneStatusRequest]) (*connect.Response[llmv1.FineTuneStatusResponse], error) {
32	hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
33	if err != nil {
34		return nil, errors.NewInternal("failed to get HuggingFace provider")
35	}
36
37	result, err := hfProvider.(*huggingface.HuggingFaceProvider).RetrieveFineTuningResults(ctx, req.Payload.JobId)
38	if err != nil {
39		return nil, errors.NewInternal("failed to retrieve fine-tuning results: " + err.Error())
40	}
41
42	status, ok := result["status"].(string)
43	if !ok {
44		return nil, errors.NewInternal("unexpected response format from HuggingFace")
45	}
46
47	return connect.NewResponse(&llmv1.FineTuneStatusResponse{
48		Status: status,
49	}), nil
50}
51
  • Create a new file finetune.go in the gateway/internal/api/v1/ directory to handle API requests related to fine-tuning.
  • Define endpoints for initiating fine-tuning jobs and checking their status. Use the HuggingFaceProvider to interact with the HuggingFace API.
  • Implement handlers for the new endpoints, parsing request data, calling the appropriate methods on the HuggingFaceProvider, and formatting responses.
  Modify gateway/internal/api/v1/v1.go
1package v1
2
3import (
4	"github.com/missingstudio/ai/gateway/core/apikey"
5	"github.com/missingstudio/ai/gateway/core/prompt"
6	"github.com/missingstudio/ai/gateway/core/provider"
7	"github.com/missingstudio/ai/gateway/internal/ingester"
8	iprovider "github.com/missingstudio/ai/gateway/internal/provider"
9	"github.com/missingstudio/protos/pkg/llm/v1/llmv1connect"
10	"github.com/missingstudio/protos/pkg/prompt/v1/promptv1connect"
11)
12
13type V1Handler struct {
14	llmv1connect.UnimplementedLLMServiceHandler
15	promptv1connect.UnimplementedPromptRegistryServiceHandler
16	ingester         ingester.Ingester
17	apikeyService    *apikey.Service
18	promptService    *prompt.Service
19	providerService  *provider.Service
20	iProviderService *iprovider.Service
21}
22
23func NewHandlerV1(
24	ingester ingester.Ingester,
25	apikeyService *apikey.Service,
26	promptService *prompt.Service,
27	providerService *provider.Service,
28	iProviderService *iprovider.Service,
29) *V1Handler {
30	return &V1Handler{
31		ingester:         ingester,
32		apikeyService:    apikeyService,
33		promptService:    promptService,
34		providerService:  providerService,
35		iProviderService: iProviderService,
36	}
37}
38
  • Register the new fine-tuning endpoints defined in finetune.go with the router in v1.go.
  • Ensure that the endpoints are properly authenticated and validate request data before processing.
  Run GitHub Actions for gateway/internal/api/v1/v1.go
  Modifyplaygrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx 

Changed playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx in 3f04b91    

14const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000";14const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000";
15export function useModelFetch() {15export function useModelFetch() {
16 const [providers, setProviders] = useState<ModelType[]>([]);16 const [providers, setProviders] = useState<ModelType[]>([]);
17const [isFineTuning, setIsFineTuning] = useState<boolean>(false);
17 18
18 useEffect(() => {19 useEffect(() => {
20 const fetchEndpoint = isFineTuning ? `${BASE_URL}/api/v1/finetune/models` : `${BASE_URL}/api/v1/models`;
19 async function fetchModels() {21 async function fetchModels() {
20 try {22 try {
21 const response = await fetch(`${BASE_URL}/api/v1/models`);23 const response = await fetch(fetchEndpoint);
22 const { models } = await response.json();24 const { models } = await response.json();
23 const fetchedProviders: ModelType[] = Object.keys(models).map(25 const fetchedProviders: ModelType[] = Object.keys(models).map(
24 (key) => ({26 (key) => ({
  • Update the useModelFetch hook to include an option for selecting transformer models for fine-tuning. This may involve adding a new state to track whether the user is interested in fine-tuning and fetching additional data from the backend if necessary.
  • Adjust the API call to include requests to the new fine-tuning endpoints as needed.
  Modifyplaygrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx 

Changed playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx in bd0d237    

26 26
27export default function ModelSelector(props: ModelSelectorProps) {27export default function ModelSelector(props: ModelSelectorProps) {
28 const [open, setOpen] = React.useState(false);28 const [open, setOpen] = React.useState(false);
29 const { providers } = useModelFetch();29 const [isFineTuning, setIsFineTuning] = React.useState(false);
30 const { providers } = useModelFetch(isFineTuning);
30 const { model, setModel, setProvider } = useStore();31 const { model, setModel, setProvider } = useStore();
31 32
33 const toggleFineTuning = () => setIsFineTuning(!isFineTuning);
34
32 return (35 return (
33 <div className="flex items-center gap-2">36 <div className="flex items-center gap-2">
34 <Label htmlFor="model">Model: </Label>37 <Label htmlFor="model">Model: </Label>
38 <Button variant="outline" onClick={toggleFineTuning}>
39 {isFineTuning ? 'Select for Fine-Tuning' : 'Select Model'}
40 </Button>
35 <Popover open={open} onOpenChange={setOpen} {...props}>41 <Popover open={open} onOpenChange={setOpen} {...props}>
36 <PopoverTrigger asChild>42 <PopoverTrigger asChild>
37 <Button43 <Button
  • Modify the ModelSelector component to allow users to select transformer models for fine-tuning. This includes UI changes to present fine-tuning options and possibly a new UI component for selecting or uploading datasets for fine-tuning.
  • Ensure that the component interacts correctly with the updated useModelFetch hook to fetch and display the relevant options.

Plan

This is based on the results of the Planning step. The plan may expand from failed GitHub Actions runs.

  Create gateway/internal/provider/huggingface/huggingface.go (commit 60911bd)
1package huggingface
2
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil" // TODO: no longer referenced once ioutil.NopCloser is dropped; safe to remove
	"net/http"

	"github.com/missingstudio/ai/gateway/internal/provider/base"
	"github.com/missingstudio/common/errors"
)
12
// HuggingFaceProvider talks to the HuggingFace API on behalf of the
// gateway.
type HuggingFaceProvider struct {
	// APIKey is the HuggingFace access token, sent as a Bearer credential.
	APIKey string
	// BaseURL is the root URL of the HuggingFace API.
	BaseURL string
}
17
18func (hfp *HuggingFaceProvider) Info() base.ProviderInfo {
19	return base.ProviderInfo{
20		Name: "HuggingFace",
21		Description: "Provider for interacting with HuggingFace's transformer models",
22	}
23}
24
25func (hfp *HuggingFaceProvider) Models(ctx context.Context) ([]string, error) {
26	url := fmt.Sprintf("%s/models", hfp.BaseURL)
27	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
28	if err != nil {
29		return nil, err
30	}
31
32	req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
33	client := &http.Client{}
34	resp, err := client.Do(req)
35	if err != nil {
36		return nil, err
37	}
38	defer resp.Body.Close()
39
40	if resp.StatusCode != http.StatusOK {
41		return nil, errors.NewBadRequest("failed to fetch models from HuggingFace")
42	}
43
44	var models []string
45	err = json.NewDecoder(resp.Body).Decode(&models)
46	if err != nil {
47		return nil, err
48	}
49
50	return models, nil
51}
52
53func (hfp *HuggingFaceProvider) InitiateFineTuning(ctx context.Context, model string, parameters map[string]interface{}) (string, error) {
54	url := fmt.Sprintf("%s/fine-tune", hfp.BaseURL)
55	payload, err := json.Marshal(map[string]interface{}{
56		"model": model,
57		"parameters": parameters,
58	})
59	if err != nil {
60		return "", err
61	}
62
63	req, err := http.NewRequestWithContext(ctx, "POST", url, ioutil.NopCloser(bytes.NewReader(payload)))
64	if err != nil {
65		return "", err
66	}
67
68	req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
69	req.Header.Add("Content-Type", "application/json")
70	client := &http.Client{}
71	resp, err := client.Do(req)
72	if err != nil {
73		return "", err
74	}
75	defer resp.Body.Close()
76
77	if resp.StatusCode != http.StatusOK {
78		return "", errors.NewBadRequest("failed to initiate fine-tuning on HuggingFace")
79	}
80
81	var result map[string]string
82	err = json.NewDecoder(resp.Body).Decode(&result)
83	if err != nil {
84		return "", err
85	}
86
87	return result["job_id"], nil
88}
89
90func (hfp *HuggingFaceProvider) RetrieveFineTuningResults(ctx context.Context, jobID string) (map[string]interface{}, error) {
91	url := fmt.Sprintf("%s/fine-tune/%s", hfp.BaseURL, jobID)
92	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
93	if err != nil {
94		return nil, err
95	}
96
97	req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
98	client := &http.Client{}
99	resp, err := client.Do(req)
100	if err != nil {
101		return nil, err
102	}
103	defer resp.Body.Close()
104
105	if resp.StatusCode != http.StatusOK {
106		return nil, errors.NewBadRequest("failed to retrieve fine-tuning results from HuggingFace")
107	}
108
109	var result map[string]interface{}
110	err = json.NewDecoder(resp.Body).Decode(&result)
111	if err != nil {
112		return nil, err
113	}
114
115	return result, nil
116}
117
  Create gateway/internal/api/v1/finetune.go (commit 82f5593)
1package v1
2
3import (
4	"context"
5	"encoding/json"
6	"net/http"
7
8	"connectrpc.com/connect"
9	"github.com/missingstudio/ai/gateway/core/provider"
10	"github.com/missingstudio/ai/gateway/internal/provider/huggingface"
11	"github.com/missingstudio/common/errors"
12	llmv1 "github.com/missingstudio/protos/pkg/llm/v1"
13)
14
15func (s *V1Handler) InitiateFineTuning(ctx context.Context, req *connect.Request[llmv1.FineTuneRequest]) (*connect.Response[llmv1.FineTuneResponse], error) {
16	hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
17	if err != nil {
18		return nil, errors.NewInternal("failed to get HuggingFace provider")
19	}
20
21	jobID, err := hfProvider.(*huggingface.HuggingFaceProvider).InitiateFineTuning(ctx, req.Payload.Model, req.Payload.Parameters)
22	if err != nil {
23		return nil, errors.NewInternal("failed to initiate fine-tuning: " + err.Error())
24	}
25
26	return connect.NewResponse(&llmv1.FineTuneResponse{
27		JobId: jobID,
28	}), nil
29}
30
31func (s *V1Handler) CheckFineTuningStatus(ctx context.Context, req *connect.Request[llmv1.FineTuneStatusRequest]) (*connect.Response[llmv1.FineTuneStatusResponse], error) {
32	hfProvider, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
33	if err != nil {
34		return nil, errors.NewInternal("failed to get HuggingFace provider")
35	}
36
37	result, err := hfProvider.(*huggingface.HuggingFaceProvider).RetrieveFineTuningResults(ctx, req.Payload.JobId)
38	if err != nil {
39		return nil, errors.NewInternal("failed to retrieve fine-tuning results: " + err.Error())
40	}
41
42	status, ok := result["status"].(string)
43	if !ok {
44		return nil, errors.NewInternal("unexpected response format from HuggingFace")
45	}
46
47	return connect.NewResponse(&llmv1.FineTuneStatusResponse{
48		Status: status,
49	}), nil
50}
51
  Run GitHub Actions forgateway/internal/api/v1/finetune.go 
  Run GitHub Actions forgateway/internal/api/v1/v1.go 

Code Snippets Found

This is based on the results of the Searching step.

gateway/internal/provider/openai/openai.go:0-123 
1package openai
2
3import (
4	"bufio"
5	"bytes"
6	"context"
7	"encoding/json"
8	"fmt"
9	"net/http"
10	"strings"
11
12	"github.com/missingstudio/ai/gateway/core/chat"
13	"github.com/missingstudio/ai/gateway/core/provider"
14	"github.com/missingstudio/ai/gateway/internal/requester"
15)
16
// OpenAIModels enumerates the OpenAI chat model identifiers accepted by
// this provider.
var OpenAIModels = []string{
	"gpt-4-0125-preview",
	"gpt-4-turbo-preview",
	"gpt-4-1106-preview",
	"gpt-4-vision-preview",
	"gpt-4-1106-vision-preview",
	"gpt-4",
	"gpt-4-0613",
	"gpt-4-32k",
	"gpt-4-32k-0613",
	"gpt-3.5-turbo-0125",
	"gpt-3.5-turbo",
	"gpt-3.5-turbo-1106",
	"gpt-3.5-turbo-instruct",
}
32
33func (oai *openAIProvider) ChatCompletions(ctx context.Context, payload *chat.ChatCompletionRequest) (*chat.ChatCompletionResponse, error) {
34	client := requester.NewHTTPClient()
35
36	payload.Stream = false
37	rawPayload, err := json.Marshal(payload)
38	if err != nil {
39		return nil, fmt.Errorf("unable to marshal openai chat request payload: %w", err)
40	}
41
42	requestURL := fmt.Sprintf("%s%s", oai.config.BaseURL, oai.config.ChatCompletions)
43	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
44	if err != nil {
45		return nil, err
46	}
47
48	req = oai.AddDefaultHeaders(req, provider.AuthorizationHeader)
49	resp, err := client.SendRequestRaw(req)
50	if err != nil {
51		return nil, err
52	}
53
54	data := &chat.ChatCompletionResponse{}
55	if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
56		return nil, err
57	}
58
59	return data, nil
60}
61
62func (*openAIProvider) Models() []string {
63	return OpenAIModels
64}
65
66func (oai *openAIProvider) AddDefaultHeaders(req *http.Request, key string) *http.Request {
67	providerConfigMap := oai.provider.GetConfig([]string{
68		provider.AuthorizationHeader,
69	})
70
71	var authorizationHeader string
72	if val, ok := providerConfigMap[provider.AuthorizationHeader].(string); ok && val != "" {
73		authorizationHeader = val
74	}
75
76	req.Header.Add("Content-Type", "application/json")
77	req.Header.Add("Authorization", authorizationHeader)
78	return req
79}
80
81func (oai *openAIProvider) StreamChatCompletions(ctx context.Context, payload *chat.ChatCompletionRequest, stream chan []byte) error {
82	client := requester.NewHTTPClient()
83
84	payload.Stream = true
85	rawPayload, err := json.Marshal(payload)
86	if err != nil {
87		return fmt.Errorf("unable to marshal openai chat request payload: %w", err)
88	}
89
90	requestURL := fmt.Sprintf("%s%s", oai.config.BaseURL, oai.config.ChatCompletions)
91	req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(rawPayload))
92	if err != nil {
93		return err
94	}
95
96	req = oai.AddDefaultHeaders(req, provider.AuthorizationHeader)
97	resp, err := client.SendRequestRaw(req)
98	if err != nil {
99		return err
100	}
101	defer resp.Body.Close()
102
103	scanner := bufio.NewScanner(resp.Body)
104	for scanner.Scan() {
105		stream <- scanner.Bytes()
106
107		line := scanner.Text()
108		if strings.HasPrefix(line, "data:") {
109			event := strings.TrimPrefix(line, "data:")
110			event = strings.TrimSpace(strings.Trim(event, "\n"))
111			if strings.Contains(line, "[DONE]") {
112				break
113			}
114
115			var data chat.ChatCompletionResponse
116			if err := json.Unmarshal([]byte(event), &data); err != nil {
117				return err
118			}
119		}
120	}
121
122	close(stream)
123	return nil
gateway/internal/api/v1/models.go:0-46 
1package v1
2
3import (
4	"context"
5
6	"connectrpc.com/connect"
7	"github.com/missingstudio/ai/gateway/core/provider"
8	"github.com/missingstudio/ai/gateway/internal/provider/base"
9	llmv1 "github.com/missingstudio/protos/pkg/llm/v1"
10)
11
12func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.ModelRequest]) (*connect.Response[llmv1.ModelResponse], error) {
13	allProviderModels := map[string]*llmv1.ProviderModels{}
14
15	for name := range base.ProviderRegistry {
16		// Check if the provider is healthy before fetching models
17		if !router.DefaultHealthChecker{}.IsHealthy(name) {
18			continue
19		}
20
21		provider, err := s.iProviderService.GetProvider(provider.Provider{Name: name})
22		if err != nil {
23			continue
24		}
25
26		providerInfo := provider.Info()
27		providerModels := provider.Models()
28		providerName := providerInfo.Name
29
30		var models []*llmv1.Model
31		for _, val := range providerModels {
32			models = append(models, &llmv1.Model{
33				Name:  val,
34				Value: val,
35			})
36		}
37
38		allProviderModels[name] = &llmv1.ProviderModels{
39			Name:   providerName,
40			Models: models,
41		}
42	}
43
44	return connect.NewResponse(&llmv1.ModelResponse{
45		Models: allProviderModels,
46	}), nil
playgrounds/apps/studio/app/(llm)/playground/hooks/useModelFetch.tsx:0-38 
1import { useEffect, useState } from "react";
2import { toast } from "sonner";
3
4interface ModelType {
5  name: string;
6  models: Model[];
7}
8
9interface Model {
10  name: string;
11  value: string;
12}
13
14const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000";
15export function useModelFetch() {
16  const [providers, setProviders] = useState<ModelType[]>([]);
17
18  useEffect(() => {
19    async function fetchModels() {
20      try {
21        const response = await fetch(`${BASE_URL}/api/v1/models`);
22        const { models } = await response.json();
23        const fetchedProviders: ModelType[] = Object.keys(models).map(
24          (key) => ({
25            name: models[key].name,
26            models: models[key].models,
27          })
28        );
29        setProviders(fetchedProviders);
30      } catch (e) {
31        toast.error("AI Studio is not running");
32      }
33    }
34
35    fetchModels();
36  }, []);
37
38  return { providers };
playgrounds/apps/studio/app/(llm)/playground/components/modelselector.tsx:0-73 
1"use client";
2import { CaretSortIcon } from "@radix-ui/react-icons";
3import { PopoverProps } from "@radix-ui/react-popover";
4import React from "react";
5
6import {
7  Popover,
8  PopoverContent,
9  PopoverTrigger,
10} from "@missingstudio/ui/popover";
11
12import { Button } from "@missingstudio/ui/button";
13import {
14  Command,
15  CommandEmpty,
16  CommandGroup,
17  CommandInput,
18  CommandList,
19} from "@missingstudio/ui/command";
20import { Label } from "@missingstudio/ui/label";
21import { ModelItem } from "~/app/(llm)/playground/components/modelitem";
22import { useModelFetch } from "~/app/(llm)/playground/hooks/useModelFetch";
23import { useStore } from "~/app/(llm)/playground/store";
24
25interface ModelSelectorProps extends PopoverProps {}
26
27export default function ModelSelector(props: ModelSelectorProps) {
28  const [open, setOpen] = React.useState(false);
29  const { providers } = useModelFetch();
30  const { model, setModel, setProvider } = useStore();
31
32  return (
33    <div className="flex items-center gap-2">
34      <Label htmlFor="model">Model: </Label>
35      <Popover open={open} onOpenChange={setOpen} {...props}>
36        <PopoverTrigger asChild>
37          <Button
38            variant="outline"
39            role="combobox"
40            aria-expanded={open}
41            aria-label="Select a model"
42            className="w-full justify-between"
43          >
44            {model ? model : "Select a model..."}
45            <CaretSortIcon className="ml-2 h-4 w-4 shrink-0 opacity-50" />
46          </Button>
47        </PopoverTrigger>
48        <PopoverContent align="end" className="w-[250px] p-0">
49          <Command loop>
50            <CommandList className="h-[var(--cmdk-list-height)] max-h-[400px]">
51              <CommandInput placeholder="Search Models..." />
52              <CommandEmpty>No Models found.</CommandEmpty>
53              {providers.map((provider) => (
54                <CommandGroup key={provider.name} heading={provider.name}>
55                  {provider.models.map((singleModel, index) => (
56                    <ModelItem
57                      key={`${index}_${provider.name}_${singleModel.value}`}
58                      id={`${singleModel.value}`}
59                      isSelected={model === singleModel.value}
60                      onSelect={() => {
61                        setProvider(provider.name);
62                        setModel(singleModel.value);
63                        setOpen(false);
64                      }}
65                    />
66                  ))}
67                </CommandGroup>
68              ))}
69            </CommandList>
70          </Command>
71        </PopoverContent>
72      </Popover>
73    </div>