カスタムモデルの追加

イントロダクション

ベンダー統合が完了した後、次にベンダーの下でモデルのインテグレーションを行います。ここでは、全体のプロセスを理解するために、例としてXinferenceを使用して、段階的にベンダーのインテグレーションを完了します。

注意が必要なのは、カスタムモデルの場合、各モデルのインテグレーションには完全なベンダークレデンシャルの記入が必要です。

事前定義モデルとは異なり、カスタムベンダーのインテグレーション時には常に以下の2つのパラメータが存在し、ベンダー yaml に定義する必要はありません。

前述したように、ベンダーはvalidate_provider_credentialを実装する必要はなく、Runtimeがユーザーが選択したモデルタイプとモデル名に基づいて、対応するモデル層のvalidate_credentialsを呼び出して検証を行います。

ベンダー yaml の作成

まず、インテグレーションを行うベンダーがどのタイプのモデルをサポートしているかを確認します。

現在サポートされているモデルタイプは以下の通りです:

  • llm テキスト生成モデル

  • text_embedding テキスト Embedding モデル

  • rerank Rerank モデル

  • speech2text 音声からテキスト変換

  • tts テキストから音声変換

  • moderation モデレーション

XinferenceLLMText EmbeddingRerankをサポートしているため、xinference.yamlを作成します。

その後、Xinferenceでモデルを定義するために必要なクレデンシャルを考えます。

  • 3つの異なるモデルをサポートするため、model_typeを使用してこのモデルのタイプを指定する必要があります。3つのタイプがあるので、次のように記述します。

  • 各モデルには独自の名称model_nameがあるため、ここで定義する必要があります。

  • Xinferenceのローカルデプロイのアドレスを記入します。

  • 各モデルには一意の model_uid があるため、ここで定義する必要があります。

これで、ベンダーの基本定義が完了しました。

モデルコードの作成

次に、llmタイプを例にとって、xinference.llm.llm.pyを作成します。

llm.py内で、Xinference LLM クラスを作成し、XinferenceAILargeLanguageModel(任意の名前)と名付けて、__base.large_language_model.LargeLanguageModel基底クラスを継承し、以下のメソッドを実装します:

  • LLM 呼び出し

    LLM 呼び出しのコアメソッドを実装し、ストリームレスポンスと同期レスポンスの両方をサポートします。

    実装時には、同期レスポンスとストリームレスポンスを処理するために2つの関数を使用してデータを返す必要があります。Pythonはyieldキーワードを含む関数をジェネレータ関数として認識し、返されるデータ型は固定でジェネレーターになります。そのため、同期レスポンスとストリームレスポンスは別々に実装する必要があります。以下のように実装します(例では簡略化されたパラメータを使用していますが、実際の実装では上記のパラメータリストに従って実装してください):

  • 予測トークン数の計算

    モデルが予測トークン数の計算インターフェースを提供していない場合、直接0を返すことができます。

    時には、直接0を返す必要がない場合もあります。その場合はself._get_num_tokens_by_gpt2(text: str)を使用して予測トークン数を取得することができます。このメソッドはAIModel基底クラスにあり、GPT2のTokenizerを使用して計算を行いますが、代替方法として使用されるものであり、完全に正確ではありません。

  • モデルクレデンシャル検証

    ベンダークレデンシャル検証と同様に、ここでは個々のモデルについて検証を行います。

  • モデルパラメータスキーマ

    カスタムタイプとは異なり、yamlファイルでモデルがサポートするパラメータを定義していないため、動的にモデルパラメータのスキーマを生成する必要があります。

    例えば、Xinferenceはmax_tokenstemperaturetop_pの3つのモデルパラメータをサポートしています。

    しかし、ベンダーによっては異なるモデルに対して異なるパラメータをサポートしている場合があります。例えば、ベンダーOpenLLMtop_kをサポートしていますが、全てのモデルがtop_kをサポートしているわけではありません。ここでは、例としてAモデルがtop_kをサポートし、Bモデルがtop_kをサポートしていない場合、以下のように動的にモデルパラメータのスキーマを生成します:

  • 呼び出しエラーマッピングテーブル

    モデル呼び出し時にエラーが発生した場合、Runtimeが指定するInvokeErrorタイプにマッピングする必要があります。これにより、Difyは異なるエラーに対して異なる後続処理を行うことができます。

    Runtime Errors:

    • InvokeConnectionError 呼び出し接続エラー

    • InvokeServerUnavailableError 呼び出しサービスが利用不可

    • InvokeRateLimitError 呼び出し回数制限に達した

    • InvokeAuthorizationError 認証エラー

    • InvokeBadRequestError 不正なリクエストパラメータ

インターフェース方法の詳細については:インターフェースをご覧ください。具体的な実装例については、llm.pyを参照してください。

Last updated